Skip to content

Commit 138eaaa

Browse files
Fix issues in livecell evaluation for generalist
1 parent a16eca2 commit 138eaaa

File tree

1 file changed

+15
-2
lines changed

1 file changed

+15
-2
lines changed

finetuning/generalists/util.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,13 @@ def evaluate_checkpoint_for_dataset(
114114
prompt_save_dir=prompt_dir,
115115
)
116116

117-
pred_paths = sorted(glob(os.path.join(prediction_dir, "*.tif")))
117+
if dataset == "livecell":
118+
pred_paths = [
119+
os.path.join(prediction_dir, os.path.basename(gt_path)) for gt_path in test_gt_paths
120+
]
121+
assert all(os.path.exists(pred_path) for pred_path in pred_paths)
122+
else:
123+
pred_paths = sorted(glob(os.path.join(prediction_dir, "*.tif")))
118124
result_path = os.path.join(result_dir, f"{setting_name}.csv")
119125
os.makedirs(Path(result_path).parent, exist_ok=True)
120126

@@ -146,7 +152,14 @@ def evaluate_checkpoint_for_dataset(
146152
amg_generate_kwargs=best_settings,
147153
)
148154

149-
pred_paths = sorted(glob(os.path.join(prediction_dir, "*.tif")))
155+
if dataset == "livecell":
156+
pred_paths = [
157+
os.path.join(prediction_dir, os.path.basename(gt_path)) for gt_path in test_gt_paths
158+
]
159+
assert all(os.path.exists(pred_path) for pred_path in pred_paths)
160+
else:
161+
pred_paths = sorted(glob(os.path.join(prediction_dir, "*.tif")))
162+
150163
result_path = os.path.join(result_dir, "amg.csv")
151164
os.makedirs(Path(result_path).parent, exist_ok=True)
152165

0 commit comments

Comments
 (0)