Skip to content

Commit 9b1bac0

Browse files
Add prediction with box prompts to training integration tests
1 parent 0ec2643 commit 9b1bac0

File tree

1 file changed

+18
-5
lines changed

1 file changed

+18
-5
lines changed

test/test_training.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def _run_inference_and_check_results(
135135
self.assertEqual(len(pred_paths), len(label_paths))
136136
eval_res = evaluation.run_evaluation(label_paths, pred_paths, verbose=False)
137137
result = eval_res["sa50"].values.item()
138-
# We expect an SA50 > 90%.
138+
# We check against the expected segmentation accuracy.
139139
self.assertGreater(result, expected_sa)
140140

141141
def test_training(self):
@@ -153,21 +153,34 @@ def test_training(self):
153153
self._export_model(checkpoint_path, export_path, model_type)
154154
self.assertTrue(os.path.exists(export_path))
155155

156-
# Check the model with normal inference.
157-
prediction_dir = os.path.join(self.tmp_folder, "predictions")
156+
# Check the model with inference with a single point prompt.
157+
prediction_dir = os.path.join(self.tmp_folder, "predictions-points")
158158
normal_inference = partial(
159159
evaluation.run_inference_with_prompts,
160160
use_points=True, use_boxes=False,
161161
n_positives=1, n_negatives=0,
162-
batch_size=64
162+
batch_size=64,
163163
)
164164
self._run_inference_and_check_results(
165165
export_path, model_type, prediction_dir=prediction_dir,
166166
inference_function=normal_inference, expected_sa=0.9
167167
)
168168

169+
# Check the model with inference with a box point prompt.
170+
prediction_dir = os.path.join(self.tmp_folder, "predictions-boxes")
171+
normal_inference = partial(
172+
evaluation.run_inference_with_prompts,
173+
use_points=False, use_boxes=True,
174+
n_positives=1, n_negatives=0,
175+
batch_size=64,
176+
)
177+
self._run_inference_and_check_results(
178+
export_path, model_type, prediction_dir=prediction_dir,
179+
inference_function=normal_inference, expected_sa=0.95,
180+
)
181+
182+
# Check the model with interactive inference
169183
# TODO
170-
# Check the model with interactivel inference
171184

172185

173186
if __name__ == "__main__":

0 commit comments

Comments
 (0)