Skip to content

Commit 031d9fb

Browse files
authored
Minor update to support peft kwargs in qualitative comparsion scripts (#845)
Make qualitative comparison scripts flexible!
1 parent d930618 commit 031d9fb

File tree

1 file changed

+24
-14
lines changed

1 file changed

+24
-14
lines changed

micro_sam/evaluation/model_comparison.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from tqdm import tqdm
77
from pathlib import Path
88
from functools import partial
9-
from typing import Optional, Union
9+
from typing import Optional, Union, Dict, Any
1010

1111
import h5py
1212
import numpy as np
@@ -124,6 +124,9 @@ def generate_data_for_model_comparison(
124124
checkpoint1: Optional[Union[str, os.PathLike]] = None,
125125
checkpoint2: Optional[Union[str, os.PathLike]] = None,
126126
checkpoint3: Optional[Union[str, os.PathLike]] = None,
127+
peft_kwargs1: Optional[Dict[str, Any]] = None,
128+
peft_kwargs2: Optional[Dict[str, Any]] = None,
129+
peft_kwargs3: Optional[Dict[str, Any]] = None,
127130
) -> None:
128131
"""Generate samples for qualitative model comparison.
129132
@@ -149,11 +152,11 @@ def generate_data_for_model_comparison(
149152
get_box_prompts=True,
150153
)
151154

152-
predictor1 = util.get_sam_model(model_type=model_type1, checkpoint_path=checkpoint1)
153-
predictor2 = util.get_sam_model(model_type=model_type2, checkpoint_path=checkpoint2)
155+
predictor1 = util.get_sam_model(model_type=model_type1, checkpoint_path=checkpoint1, peft_kwargs=peft_kwargs1)
156+
predictor2 = util.get_sam_model(model_type=model_type2, checkpoint_path=checkpoint2, peft_kwargs=peft_kwargs2)
154157

155158
if model_type3 is not None:
156-
predictor3 = util.get_sam_model(model_type=model_type3, checkpoint_path=checkpoint3)
159+
predictor3 = util.get_sam_model(model_type=model_type3, checkpoint_path=checkpoint3, peft_kwargs=peft_kwargs3)
157160
else:
158161
predictor3 = None
159162

@@ -262,8 +265,8 @@ def _overlay_points(im, prompt, radius):
262265

263266

264267
def _compare_eval(
265-
f, eval_result, advantage_column, n_images_per_sample, prefix,
266-
sample_name, plot_folder, point_radius, outline_dilation, have_model3,
268+
f, eval_result, advantage_column, n_images_per_sample, prefix, sample_name,
269+
plot_folder, point_radius, outline_dilation, have_model3, enhance_image,
267270
):
268271
result = eval_result.sort_values(advantage_column, ascending=False).iloc[:n_images_per_sample]
269272
n_rows = result.shape[0]
@@ -313,7 +316,11 @@ def plot_ax(axis, i, row):
313316
else:
314317
prompt = (g.attrs["point_coords"] - offset, g.attrs["point_labels"])
315318

316-
im = _enhance_image(image[bb])
319+
if enhance_image:
320+
im = _enhance_image(image[bb])
321+
else:
322+
im = image[bb]
323+
317324
gt, mask1, mask2 = gt[bb], mask1[bb], mask2[bb]
318325

319326
if have_model3:
@@ -364,7 +371,7 @@ def plot_ax(axis, i, row):
364371

365372
def _compare_prompts(
366373
f, prefix, n_images_per_sample, min_size, sample_name, plot_folder,
367-
point_radius, outline_dilation, have_model3,
374+
point_radius, outline_dilation, have_model3, enhance_image,
368375
):
369376
box_eval = _evaluate_samples(f, prefix, min_size)
370377
if plot_folder is None:
@@ -376,16 +383,16 @@ def _compare_prompts(
376383
os.makedirs(plot_folder2, exist_ok=True)
377384
_compare_eval(
378385
f, box_eval, "advantage1", n_images_per_sample, prefix, sample_name, plot_folder1,
379-
point_radius, outline_dilation, have_model3,
386+
point_radius, outline_dilation, have_model3, enhance_image,
380387
)
381388
_compare_eval(
382389
f, box_eval, "advantage2", n_images_per_sample, prefix, sample_name, plot_folder2,
383-
point_radius, outline_dilation, have_model3,
390+
point_radius, outline_dilation, have_model3, enhance_image,
384391
)
385392

386393

387394
def _compare_models(
388-
path, n_images_per_sample, min_size, plot_folder, point_radius, outline_dilation, have_model3,
395+
path, n_images_per_sample, min_size, plot_folder, point_radius, outline_dilation, have_model3, enhance_image,
389396
):
390397
sample_name = Path(path).stem
391398
with h5py.File(path, "r") as f:
@@ -396,11 +403,11 @@ def _compare_models(
396403
plot_folder_box = os.path.join(plot_folder, "box")
397404
_compare_prompts(
398405
f, "points", n_images_per_sample, min_size, sample_name, plot_folder_points,
399-
point_radius, outline_dilation, have_model3,
406+
point_radius, outline_dilation, have_model3, enhance_image,
400407
)
401408
_compare_prompts(
402409
f, "box", n_images_per_sample, min_size, sample_name, plot_folder_box,
403-
point_radius, outline_dilation, have_model3,
410+
point_radius, outline_dilation, have_model3, enhance_image,
404411
)
405412

406413

@@ -412,6 +419,7 @@ def model_comparison(
412419
point_radius: int = 4,
413420
outline_dilation: int = 0,
414421
have_model3=False,
422+
enhance_image=True,
415423
) -> None:
416424
"""Create images for a qualitative model comparision.
417425
@@ -422,11 +430,13 @@ def model_comparison(
422430
plot_folder: The folder where to save the plots. If not given the plots will be displayed.
423431
point_radius: The radius of the point overlay.
424432
outline_dilation: The dilation factor of the outline overlay.
433+
enhance_image: Whether to enhance the input image.
425434
"""
426435
files = glob(os.path.join(output_folder, "*.h5"))
427436
for path in tqdm(files):
428437
_compare_models(
429-
path, n_images_per_sample, min_size, plot_folder, point_radius, outline_dilation, have_model3,
438+
path, n_images_per_sample, min_size, plot_folder, point_radius,
439+
outline_dilation, have_model3, enhance_image,
430440
)
431441

432442

0 commit comments

Comments
 (0)