Skip to content

Commit ca7e02c

Browse files
authored
Update LiveCELL inference scripts (#315)
* Update livecell inference
1 parent 8ec6df9 commit ca7e02c

File tree

8 files changed

+128
-60
lines changed

8 files changed

+128
-60
lines changed

finetuning/livecell/evaluation/evaluate_amg.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,10 @@
33

44
from micro_sam.evaluation.evaluation import run_evaluation
55
from micro_sam.evaluation.livecell import run_livecell_amg
6-
from util import DATA_ROOT, get_checkpoint, get_experiment_folder, get_pred_and_gt_paths
6+
from util import DATA_ROOT, get_pred_and_gt_paths
77

88

9-
def run_amg(name, model_type, checkpoint):
10-
if checkpoint is None:
11-
checkpoint, model_type = get_checkpoint(name)
12-
experiment_folder = get_experiment_folder(name)
9+
def run_amg(model_type, checkpoint, experiment_folder):
1310
input_folder = DATA_ROOT
1411
prediction_folder = run_livecell_amg(
1512
checkpoint,
@@ -21,28 +18,26 @@ def run_amg(name, model_type, checkpoint):
2118
return prediction_folder
2219

2320

24-
def eval_amg(name, prediction_folder):
21+
def eval_amg(prediction_folder, experiment_folder):
2522
print("Evaluating", prediction_folder)
2623
pred_paths, gt_paths = get_pred_and_gt_paths(prediction_folder)
27-
save_path = os.path.join(get_experiment_folder(name), "results", "amg.csv")
24+
save_path = os.path.join(experiment_folder, "results", "amg.csv")
2825
res = run_evaluation(gt_paths, pred_paths, save_path=save_path)
2926
print(res)
3027

3128

3229
def main():
3330
parser = argparse.ArgumentParser()
34-
parser.add_argument("-n", "--name", required=True)
3531
parser.add_argument(
36-
"-m", "--model", type=str, # options: "vit_h", "vit_h_generalist", "vit_h_specialist"
32+
"-m", "--model", type=str, required=True,
3733
help="Provide the model type to initialize the predictor"
3834
)
39-
parser.add_argument("-c", "--checkpoint", type=str, default=None)
35+
parser.add_argument("-c", "--checkpoint", type=str, required=True)
36+
parser.add_argument("-e", "--experiment_folder", type=str, required=True)
4037
args = parser.parse_args()
4138

42-
name = args.name
43-
44-
prediction_folder = run_amg(name, args.model, args.checkpoint)
45-
eval_amg(name, prediction_folder)
39+
prediction_folder = run_amg(args.model, args.checkpoint, args.experiment_folder)
40+
eval_amg(prediction_folder, args.experiment_folder)
4641

4742

4843
if __name__ == "__main__":

finetuning/livecell/evaluation/evaluate_amg.sbatch

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#SBATCH -G A100:1
77
#SBATCH -A nim00007
88

9-
source ~/.bashrc
10-
micromamba activate main
11-
python evaluate_amg.py $@
9+
source activate sam
10+
python evaluate_amg.py -c /scratch/usr/nimanwai/micro-sam/checkpoints/vit_b/livecell_sam/best.pt \
11+
-m vit_b \
12+
-e /scratch/projects/nim00007/sam/experiments/new_models/specialists/livecell/vit_b/

finetuning/livecell/evaluation/evaluate_instance_segmentation.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,10 @@
33

44
from micro_sam.evaluation.evaluation import run_evaluation
55
from micro_sam.evaluation.livecell import run_livecell_instance_segmentation_with_decoder
6-
from util import DATA_ROOT, get_checkpoint, get_experiment_folder, get_pred_and_gt_paths
6+
from util import DATA_ROOT, get_pred_and_gt_paths
77

88

9-
def run_instance_segmentation_with_decoder(name, model_type, checkpoint):
10-
if checkpoint is None:
11-
checkpoint, model_type = get_checkpoint(name)
12-
experiment_folder = get_experiment_folder(name)
9+
def run_instance_segmentation_with_decoder(model_type, checkpoint, experiment_folder):
1310
input_folder = DATA_ROOT
1411
prediction_folder = run_livecell_instance_segmentation_with_decoder(
1512
checkpoint,
@@ -21,28 +18,28 @@ def run_instance_segmentation_with_decoder(name, model_type, checkpoint):
2118
return prediction_folder
2219

2320

24-
def eval_instance_segmentation_with_decoder(name, prediction_folder):
21+
def eval_instance_segmentation_with_decoder(prediction_folder, experiment_folder):
2522
print("Evaluating", prediction_folder)
2623
pred_paths, gt_paths = get_pred_and_gt_paths(prediction_folder)
27-
save_path = os.path.join(get_experiment_folder(name), "results", "instance_segmentation_with_decoder.csv")
24+
save_path = os.path.join(experiment_folder, "results", "instance_segmentation_with_decoder.csv")
2825
res = run_evaluation(gt_paths, pred_paths, save_path=save_path)
2926
print(res)
3027

3128

3229
def main():
3330
parser = argparse.ArgumentParser()
34-
parser.add_argument("-n", "--name", required=True)
3531
parser.add_argument(
36-
"-m", "--model", type=str, # options: "vit_h", "vit_h_generalist", "vit_h_specialist"
32+
"-m", "--model", type=str, required=True,
3733
help="Provide the model type to initialize the predictor"
3834
)
39-
parser.add_argument("-c", "--checkpoint", type=str, default=None)
35+
parser.add_argument("-c", "--checkpoint", type=str, required=True,)
36+
parser.add_argument("-e", "--experiment_folder", type=str, required=True)
4037
args = parser.parse_args()
4138

42-
name = args.name
43-
44-
prediction_folder = run_instance_segmentation_with_decoder(name, args.model, args.checkpoint)
45-
eval_instance_segmentation_with_decoder(name, prediction_folder)
39+
prediction_folder = run_instance_segmentation_with_decoder(
40+
args.model, args.checkpoint, args.experiment_folder
41+
)
42+
eval_instance_segmentation_with_decoder(prediction_folder, args.experiment_folder)
4643

4744

4845
if __name__ == "__main__":

finetuning/livecell/evaluation/evaluate_instance_segmentation.sbatch

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#SBATCH -G A100:1
77
#SBATCH -A nim00007
88

9-
source ~/.bashrc
10-
micromamba activate main
11-
python evaluate_instance_segmentation.py $@
9+
source activate sam
10+
python evaluate_instance_segmentation.py -c /scratch/usr/nimanwai/micro-sam/checkpoints/vit_h/livecell_sam/best.pt \
11+
-m vit_h \
12+
-e /scratch/projects/nim00007/sam/experiments/new_models/specialists/livecell/vit_h/

finetuning/livecell/evaluation/iterative_prompting.py

Lines changed: 10 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from micro_sam.evaluation import inference
88
from micro_sam.evaluation.evaluation import run_evaluation
9-
from util import get_paths, get_experiment_folder, get_model, get_pred_and_gt_paths
9+
from util import get_paths, get_model, get_pred_and_gt_paths
1010

1111

1212
def run_interactive_prompting(exp_folder, predictor, start_with_box_prompt):
@@ -26,16 +26,9 @@ def run_interactive_prompting(exp_folder, predictor, start_with_box_prompt):
2626
return prediction_root
2727

2828

29-
def evaluate_interactive_prompting(prediction_root, start_with_box_prompt, name):
29+
def evaluate_interactive_prompting(prediction_root, start_with_box_prompt, exp_folder):
3030
assert os.path.exists(prediction_root), prediction_root
3131

32-
csv_save_dir = f"./iterative_prompting_results/{name}"
33-
os.makedirs(csv_save_dir, exist_ok=True)
34-
csv_path = os.path.join(csv_save_dir, "start_with_box.csv" if start_with_box_prompt else "start_with_point.csv")
35-
if os.path.exists(csv_path):
36-
print("The evaluated results for the expected setting already exist here:", csv_path)
37-
return
38-
3932
prediction_folders = sorted(glob(os.path.join(prediction_root, "iteration*")))
4033
list_of_results = []
4134
for pred_folder in prediction_folders:
@@ -46,10 +39,9 @@ def evaluate_interactive_prompting(prediction_root, start_with_box_prompt, name)
4639
print(res)
4740

4841
df = pd.concat(list_of_results, ignore_index=True)
49-
df.to_csv(csv_path)
5042

51-
# Also save the results in the experiment folder.
52-
result_folder = os.path.join(get_experiment_folder(name), "results")
43+
# Save the results in the experiment folder.
44+
result_folder = os.path.join(exp_folder, "results")
5345
os.makedirs(result_folder, exist_ok=True)
5446
csv_path = os.path.join(
5547
result_folder,
@@ -60,25 +52,21 @@ def evaluate_interactive_prompting(prediction_root, start_with_box_prompt, name)
6052

6153
def main():
6254
parser = argparse.ArgumentParser()
63-
64-
parser.add_argument("-n", "--name", required=True)
6555
parser.add_argument(
66-
"-m", "--model", type=str, # options: "vit_h", "vit_h_generalist", "vit_h_specialist"
67-
help="Provide the model type to initialize the predictor"
56+
"-m", "--model", type=str, required=True, help="Provide the model type to initialize the predictor"
6857
)
69-
parser.add_argument("-c", "--checkpoint", type=str, default=None)
58+
parser.add_argument("-c", "--checkpoint", type=str, required=True)
59+
parser.add_argument("-e", "--experiment_folder", type=str, required=True)
7060
parser.add_argument("--box", action="store_true", help="If passed, starts with first prompt as box")
7161
args = parser.parse_args()
7262

73-
name = args.name
7463
start_with_box_prompt = args.box # overwrite to start first iters' prompt with box instead of single point
7564

7665
# get the predictor to perform inference
77-
predictor = get_model(name, model_type=args.model, ckpt=args.checkpoint)
66+
predictor = get_model(model_type=args.model, ckpt=args.checkpoint)
7867

79-
exp_folder = get_experiment_folder(name)
80-
prediction_root = run_interactive_prompting(exp_folder, predictor, start_with_box_prompt)
81-
evaluate_interactive_prompting(prediction_root, start_with_box_prompt, name)
68+
prediction_root = run_interactive_prompting(args.experiment_folder, predictor, start_with_box_prompt)
69+
evaluate_interactive_prompting(prediction_root, start_with_box_prompt)
8270

8371

8472
if __name__ == "__main__":

finetuning/livecell/evaluation/iterative_prompting.sbatch

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#SBATCH -G A100:1
77
#SBATCH -A nim00007
88

9-
source ~/.bashrc
10-
micromamba activate main
11-
python iterative_prompting.py $@
9+
source activate sam
10+
python iterative_prompting.py -c /scratch/usr/nimanwai/micro-sam/checkpoints/vit_h/livecell_sam/best.pt \
11+
-m vit_h \
12+
-e /scratch/projects/nim00007/sam/experiments/new_models/specialists/livecell/vit_h/
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import os
2+
import shutil
3+
import subprocess
4+
from glob import glob
5+
from datetime import datetime
6+
7+
8+
def write_batch_script(env_name, out_path, inference_setup, checkpoint, model_type, experiment_folder):
9+
"""Writing scripts with different fold-trainings for micro-sam evaluation
10+
"""
11+
batch_script = f"""#!/bin/bash
12+
#SBATCH -c 8
13+
#SBATCH --mem 128G
14+
#SBATCH -t 6:00:00
15+
#SBATCH -p grete:shared
16+
#SBATCH -G A100:1
17+
#SBATCH -A nim00007
18+
#SBATCH --job-name={inference_setup}
19+
20+
source ~/.bashrc
21+
mamba activate {env_name}
22+
python {inference_setup}.py """
23+
24+
_op = out_path[:-3] + f"_{inference_setup}.sh"
25+
26+
# add the finetuned checkpoint
27+
batch_script += f"-c {checkpoint} "
28+
29+
# name of the model configuration
30+
batch_script += f"-m {model_type} "
31+
32+
# experiment folder
33+
batch_script += f"-e {experiment_folder} "
34+
35+
with open(_op, "w") as f:
36+
f.write(batch_script)
37+
38+
39+
def get_batch_script_names(tmp_folder):
40+
tmp_folder = os.path.expanduser(tmp_folder)
41+
os.makedirs(tmp_folder, exist_ok=True)
42+
43+
script_name = "livecell-inference"
44+
45+
dt = datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f")
46+
tmp_name = script_name + dt
47+
batch_script = os.path.join(tmp_folder, f"{tmp_name}.sh")
48+
49+
return batch_script
50+
51+
52+
def submit_slurm():
53+
"""Submit python script that needs gpus with given inputs on a slurm node.
54+
"""
55+
tmp_folder = "./gpu_jobs"
56+
57+
# parameters to run the inference scripts
58+
env_name = "sam"
59+
checkpoint = "/scratch/usr/nimanwai/micro-sam/checkpoints/vit_h/livecell_sam/best.pt"
60+
model_type = "vit_h"
61+
experiment_folder = "/scratch/projects/nim00007/sam/experiments/new_models/specialists/livecell/vit_h/"
62+
63+
all_setups = ["evaluate_amg", "evaluate_instance_segmentation", "iterative_prompting"]
64+
for current_setup in all_setups:
65+
write_batch_script(
66+
env_name=env_name,
67+
out_path=get_batch_script_names(tmp_folder),
68+
inference_setup=current_setup,
69+
checkpoint=checkpoint,
70+
model_type=model_type,
71+
experiment_folder=experiment_folder,
72+
)
73+
74+
for my_script in glob(tmp_folder + "/*"):
75+
cmd = ["sbatch", my_script]
76+
subprocess.run(cmd)
77+
78+
79+
if __name__ == "__main__":
80+
try:
81+
shutil.rmtree("./gpu_jobs")
82+
except FileNotFoundError:
83+
pass
84+
85+
submit_slurm()

finetuning/livecell/evaluation/util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def get_checkpoint(name):
3232
return ckpt, model_type
3333

3434

35-
def get_model(name, model_type=None, ckpt=None):
35+
def get_model(name=None, model_type=None, ckpt=None):
3636
if ckpt is None:
3737
ckpt, model_type = get_checkpoint(name)
3838
assert (ckpt is not None) and (model_type is not None)

0 commit comments

Comments
 (0)