Skip to content

Commit ac211b7

Browse files
authored
Minor fix to livecell eval. scripts (#316)
* Fix evaluation script * Update precomputing embeddings for val + other minor fixes
1 parent 8eac7f9 commit ac211b7

File tree

5 files changed

+47
-25
lines changed

5 files changed

+47
-25
lines changed

finetuning/livecell/evaluation/iterative_prompting.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def main():
6666
predictor = get_model(model_type=args.model, ckpt=args.checkpoint)
6767

6868
prediction_root = run_interactive_prompting(args.experiment_folder, predictor, start_with_box_prompt)
69-
evaluate_interactive_prompting(prediction_root, start_with_box_prompt)
69+
evaluate_interactive_prompting(prediction_root, start_with_box_prompt, args.experiment_folder)
7070

7171

7272
if __name__ == "__main__":

finetuning/livecell/evaluation/precompute_embeddings.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,26 @@
22
import os
33

44
from micro_sam.evaluation import precompute_all_embeddings
5-
from util import get_paths, get_model, get_experiment_folder
5+
from util import get_paths, get_model
66

77

88
def main():
99
parser = argparse.ArgumentParser()
10-
parser.add_argument("-n", "--name", required=True)
11-
parser.add_argument("-m", "--model")
12-
parser.add_argument("-c", "--checkpoint")
13-
10+
parser.add_argument("-m", "--model", type=str, required=True)
11+
parser.add_argument("-c", "--checkpoint", type=str, required=True)
12+
parser.add_argument("-e", "--experiment_folder", type=str, required=True)
1413
args = parser.parse_args()
15-
name = args.name
1614

17-
image_paths, _ = get_paths()
18-
predictor = get_model(name, model_type=args.model, ckpt=args.checkpoint)
19-
exp_folder = get_experiment_folder(name)
20-
embedding_dir = os.path.join(exp_folder, "embeddings")
15+
predictor = get_model(model_type=args.model, ckpt=args.checkpoint)
16+
embedding_dir = os.path.join(args.experiment_folder, "embeddings")
2117
os.makedirs(embedding_dir, exist_ok=True)
18+
19+
# getting the embeddings for the test set
20+
image_paths, _ = get_paths("test")
21+
precompute_all_embeddings(predictor, image_paths, embedding_dir)
22+
23+
# getting the embeddings for the val set
24+
image_paths, _ = get_paths("val")
2225
precompute_all_embeddings(predictor, image_paths, embedding_dir)
2326

2427

finetuning/livecell/evaluation/precompute_embeddings.sbatch

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,6 @@
99

1010
source ~/.bashrc
1111
micromamba activate main
12-
python precompute_embeddings.py $@
12+
python precompute_embeddings.py -c /scratch/usr/nimanwai/micro-sam/checkpoints/vit_b/lm_generalist_sam/best.pt \
13+
-m vit_b \
14+
-e /scratch/projects/nim00007/sam/experiments/new_models/generalists/livecell/vit_b/

finetuning/livecell/evaluation/submit_evaluation.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from datetime import datetime
66

77

8-
def write_batch_script(env_name, out_path, inference_setup, checkpoint, model_type, experiment_folder):
8+
def write_batch_script(env_name, out_path, inference_setup, checkpoint, model_type, experiment_folder, delay=True):
99
"""Writing scripts with different fold-trainings for micro-sam evaluation
1010
"""
1111
batch_script = f"""#!/bin/bash
@@ -17,24 +17,40 @@ def write_batch_script(env_name, out_path, inference_setup, checkpoint, model_ty
1717
#SBATCH -A nim00007
1818
#SBATCH --job-name={inference_setup}
1919
20-
source ~/.bashrc
21-
mamba activate {env_name}
22-
python {inference_setup}.py """
20+
source ~/.bashrc
21+
mamba activate {env_name} \n"""
22+
23+
if delay:
24+
batch_script += "sleep 10m \n"
25+
26+
# python script
27+
python_script = f"python {inference_setup}.py "
2328

2429
_op = out_path[:-3] + f"_{inference_setup}.sh"
2530

2631
# add the finetuned checkpoint
27-
batch_script += f"-c {checkpoint} "
32+
python_script += f"-c {checkpoint} "
2833

2934
# name of the model configuration
30-
batch_script += f"-m {model_type} "
35+
python_script += f"-m {model_type} "
3136

3237
# experiment folder
33-
batch_script += f"-e {experiment_folder} "
38+
python_script += f"-e {experiment_folder} "
39+
40+
# let's add the python script to the bash script
41+
batch_script += python_script
3442

3543
with open(_op, "w") as f:
3644
f.write(batch_script)
3745

46+
# we run the first prompt for iterative once starting with point, and then starting with box (below)
47+
if inference_setup == "iterative_prompting":
48+
batch_script += "--box "
49+
50+
new_path = out_path[:-3] + f"_{inference_setup}_box.sh"
51+
with open(new_path, "w") as f:
52+
f.write(batch_script)
53+
3854

3955
def get_batch_script_names(tmp_folder):
4056
tmp_folder = os.path.expanduser(tmp_folder)
@@ -56,11 +72,11 @@ def submit_slurm():
5672

5773
# parameters to run the inference scripts
5874
env_name = "sam"
59-
checkpoint = "/scratch/usr/nimanwai/micro-sam/checkpoints/vit_h/livecell_sam/best.pt"
60-
model_type = "vit_h"
61-
experiment_folder = "/scratch/projects/nim00007/sam/experiments/new_models/specialists/livecell/vit_h/"
75+
model_type = "vit_b"
76+
checkpoint = f"/scratch/usr/nimanwai/micro-sam/checkpoints/{model_type}/lm_generalist_sam/best.pt"
77+
experiment_folder = f"/scratch/projects/nim00007/sam/experiments/new_models/generalists/livecell/{model_type}/"
6278

63-
all_setups = ["evaluate_amg", "evaluate_instance_segmentation", "iterative_prompting"]
79+
all_setups = ["precompute_embeddings", "evaluate_amg", "evaluate_instance_segmentation", "iterative_prompting"]
6480
for current_setup in all_setups:
6581
write_batch_script(
6682
env_name=env_name,
@@ -69,6 +85,7 @@ def submit_slurm():
6985
checkpoint=checkpoint,
7086
model_type=model_type,
7187
experiment_folder=experiment_folder,
88+
delay=False if current_setup == "precompute_embeddings" else True
7289
)
7390

7491
for my_script in glob(tmp_folder + "/*"):

finetuning/livecell/evaluation/util.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
}
2020

2121

22-
def get_paths():
23-
return _get_livecell_paths(DATA_ROOT)
22+
def get_paths(split="test"):
23+
return _get_livecell_paths(DATA_ROOT, split=split)
2424

2525

2626
def get_checkpoint(name):

0 commit comments

Comments
 (0)