@@ -135,7 +135,7 @@ def _run_inference_and_check_results(
135135 self .assertEqual (len (pred_paths ), len (label_paths ))
136136 eval_res = evaluation .run_evaluation (label_paths , pred_paths , verbose = False )
137137 result = eval_res ["sa50" ].values .item ()
138- # We expect an SA50 > 90% .
138+ # We check against the expected segmentation accuracy .
139139 self .assertGreater (result , expected_sa )
140140
141141 def test_training (self ):
@@ -153,21 +153,34 @@ def test_training(self):
153153 self ._export_model (checkpoint_path , export_path , model_type )
154154 self .assertTrue (os .path .exists (export_path ))
155155
156- # Check the model with normal inference.
157- prediction_dir = os .path .join (self .tmp_folder , "predictions" )
156+ # Check the model with inference with a single point prompt .
157+ prediction_dir = os .path .join (self .tmp_folder , "predictions-points " )
158158 normal_inference = partial (
159159 evaluation .run_inference_with_prompts ,
160160 use_points = True , use_boxes = False ,
161161 n_positives = 1 , n_negatives = 0 ,
162- batch_size = 64
162+ batch_size = 64 ,
163163 )
164164 self ._run_inference_and_check_results (
165165 export_path , model_type , prediction_dir = prediction_dir ,
166166 inference_function = normal_inference , expected_sa = 0.9
167167 )
168168
169+ # Check the model with inference with a box point prompt.
170+ prediction_dir = os .path .join (self .tmp_folder , "predictions-boxes" )
171+ normal_inference = partial (
172+ evaluation .run_inference_with_prompts ,
173+ use_points = False , use_boxes = True ,
174+ n_positives = 1 , n_negatives = 0 ,
175+ batch_size = 64 ,
176+ )
177+ self ._run_inference_and_check_results (
178+ export_path , model_type , prediction_dir = prediction_dir ,
179+ inference_function = normal_inference , expected_sa = 0.95 ,
180+ )
181+
182+ # Check the model with interactive inference
169183 # TODO
170- # Check the model with interactivel inference
171184
172185
173186if __name__ == "__main__" :
0 commit comments