66from tqdm import tqdm
77from pathlib import Path
88from functools import partial
9- from typing import Optional , Union
9+ from typing import Optional , Union , Dict , Any
1010
1111import h5py
1212import numpy as np
@@ -124,6 +124,9 @@ def generate_data_for_model_comparison(
124124 checkpoint1 : Optional [Union [str , os .PathLike ]] = None ,
125125 checkpoint2 : Optional [Union [str , os .PathLike ]] = None ,
126126 checkpoint3 : Optional [Union [str , os .PathLike ]] = None ,
127+ peft_kwargs1 : Optional [Dict [str , Any ]] = None ,
128+ peft_kwargs2 : Optional [Dict [str , Any ]] = None ,
129+ peft_kwargs3 : Optional [Dict [str , Any ]] = None ,
127130) -> None :
128131 """Generate samples for qualitative model comparison.
129132
@@ -149,11 +152,11 @@ def generate_data_for_model_comparison(
149152 get_box_prompts = True ,
150153 )
151154
152- predictor1 = util .get_sam_model (model_type = model_type1 , checkpoint_path = checkpoint1 )
153- predictor2 = util .get_sam_model (model_type = model_type2 , checkpoint_path = checkpoint2 )
155+ predictor1 = util .get_sam_model (model_type = model_type1 , checkpoint_path = checkpoint1 , peft_kwargs = peft_kwargs1 )
156+ predictor2 = util .get_sam_model (model_type = model_type2 , checkpoint_path = checkpoint2 , peft_kwargs = peft_kwargs2 )
154157
155158 if model_type3 is not None :
156- predictor3 = util .get_sam_model (model_type = model_type3 , checkpoint_path = checkpoint3 )
159+ predictor3 = util .get_sam_model (model_type = model_type3 , checkpoint_path = checkpoint3 , peft_kwargs = peft_kwargs3 )
157160 else :
158161 predictor3 = None
159162
@@ -262,8 +265,8 @@ def _overlay_points(im, prompt, radius):
262265
263266
264267def _compare_eval (
265- f , eval_result , advantage_column , n_images_per_sample , prefix ,
266- sample_name , plot_folder , point_radius , outline_dilation , have_model3 ,
268+ f , eval_result , advantage_column , n_images_per_sample , prefix , sample_name ,
269+ plot_folder , point_radius , outline_dilation , have_model3 , enhance_image ,
267270):
268271 result = eval_result .sort_values (advantage_column , ascending = False ).iloc [:n_images_per_sample ]
269272 n_rows = result .shape [0 ]
@@ -313,7 +316,11 @@ def plot_ax(axis, i, row):
313316 else :
314317 prompt = (g .attrs ["point_coords" ] - offset , g .attrs ["point_labels" ])
315318
316- im = _enhance_image (image [bb ])
319+ if enhance_image :
320+ im = _enhance_image (image [bb ])
321+ else :
322+ im = image [bb ]
323+
317324 gt , mask1 , mask2 = gt [bb ], mask1 [bb ], mask2 [bb ]
318325
319326 if have_model3 :
@@ -364,7 +371,7 @@ def plot_ax(axis, i, row):
364371
365372def _compare_prompts (
366373 f , prefix , n_images_per_sample , min_size , sample_name , plot_folder ,
367- point_radius , outline_dilation , have_model3 ,
374+ point_radius , outline_dilation , have_model3 , enhance_image ,
368375):
369376 box_eval = _evaluate_samples (f , prefix , min_size )
370377 if plot_folder is None :
@@ -376,16 +383,16 @@ def _compare_prompts(
376383 os .makedirs (plot_folder2 , exist_ok = True )
377384 _compare_eval (
378385 f , box_eval , "advantage1" , n_images_per_sample , prefix , sample_name , plot_folder1 ,
379- point_radius , outline_dilation , have_model3 ,
386+ point_radius , outline_dilation , have_model3 , enhance_image ,
380387 )
381388 _compare_eval (
382389 f , box_eval , "advantage2" , n_images_per_sample , prefix , sample_name , plot_folder2 ,
383- point_radius , outline_dilation , have_model3 ,
390+ point_radius , outline_dilation , have_model3 , enhance_image ,
384391 )
385392
386393
387394def _compare_models (
388- path , n_images_per_sample , min_size , plot_folder , point_radius , outline_dilation , have_model3 ,
395+ path , n_images_per_sample , min_size , plot_folder , point_radius , outline_dilation , have_model3 , enhance_image ,
389396):
390397 sample_name = Path (path ).stem
391398 with h5py .File (path , "r" ) as f :
@@ -396,11 +403,11 @@ def _compare_models(
396403 plot_folder_box = os .path .join (plot_folder , "box" )
397404 _compare_prompts (
398405 f , "points" , n_images_per_sample , min_size , sample_name , plot_folder_points ,
399- point_radius , outline_dilation , have_model3 ,
406+ point_radius , outline_dilation , have_model3 , enhance_image ,
400407 )
401408 _compare_prompts (
402409 f , "box" , n_images_per_sample , min_size , sample_name , plot_folder_box ,
403- point_radius , outline_dilation , have_model3 ,
410+ point_radius , outline_dilation , have_model3 , enhance_image ,
404411 )
405412
406413
@@ -412,6 +419,7 @@ def model_comparison(
412419 point_radius : int = 4 ,
413420 outline_dilation : int = 0 ,
414421 have_model3 = False ,
422+ enhance_image = True ,
415423) -> None :
416424 """Create images for a qualitative model comparision.
417425
@@ -422,11 +430,13 @@ def model_comparison(
422430 plot_folder: The folder where to save the plots. If not given the plots will be displayed.
423431 point_radius: The radius of the point overlay.
424432 outline_dilation: The dilation factor of the outline overlay.
433+ enhance_image: Whether to enhance the input image.
425434 """
426435 files = glob (os .path .join (output_folder , "*.h5" ))
427436 for path in tqdm (files ):
428437 _compare_models (
429- path , n_images_per_sample , min_size , plot_folder , point_radius , outline_dilation , have_model3 ,
438+ path , n_images_per_sample , min_size , plot_folder , point_radius ,
439+ outline_dilation , have_model3 , enhance_image ,
430440 )
431441
432442
0 commit comments