88from __future__ import annotations
99
1010import copy
11+ import logging as log
1112import types
1213from contextlib import contextmanager
1314from typing import TYPE_CHECKING , Any , Callable , Iterator , Literal , Sequence
4950class OTXInstanceSegModel (OTXModel ):
5051 """Base class for the Instance Segmentation models used in OTX.
5152
53+ NOTE: OTXInstanceSegModel has many duplicate methods to OTXDetectionModel,
54+ however, it is not a subclass of OTXDetectionModel because it has different
55+ export parameters and different metric computation. Some refactor could be done
56+ to reduce the code duplication in the future.
57+
5258 Args:
5359 label_info (LabelInfoTypes | int | Sequence): Information about the labels used in the model.
5460 If `int` is given, label info will be constructed from number of classes,
@@ -264,35 +270,96 @@ def _export_parameters(self) -> TaskLevelExportParameters:
264270 label_info = modified_label_info ,
265271 )
266272
273+ def test_step (self , batch : OTXDataBatch , batch_idx : int ) -> OTXPredBatch :
274+ """Perform a single test step on a batch of data from the test set.
275+
276+ :param batch: A batch of data (a tuple) containing the input tensor of images and target
277+ labels.
278+ :param batch_idx: The index of the current batch.
279+ """
280+ preds = self .forward (inputs = batch )
281+
282+ if isinstance (preds , OTXBatchLossEntity ):
283+ raise TypeError (preds )
284+
285+ # 1. Filter outputs by threshold
286+ preds = self ._filter_outputs_by_threshold (preds )
287+ metric_inputs = self ._convert_pred_entity_to_compute_metric (preds , batch )
288+
289+ # 2. Update metric
290+ if isinstance (metric_inputs , dict ):
291+ self .metric .update (** metric_inputs )
292+ return preds
293+
294+ if isinstance (metric_inputs , list ) and all (isinstance (inp , dict ) for inp in metric_inputs ):
295+ for inp in metric_inputs :
296+ self .metric .update (** inp )
297+ return preds
298+
299+ raise TypeError (metric_inputs )
300+
301+ def predict_step (
302+ self ,
303+ batch : OTXDataBatch | OTXTileBatchDataEntity ,
304+ batch_idx : int ,
305+ dataloader_idx : int = 0 ,
306+ ) -> OTXPredBatch :
307+ """Step function called during PyTorch Lightning Trainer's predict."""
308+ if self .explain_mode :
309+ return self ._filter_outputs_by_threshold (self .forward_explain (inputs = batch )) # type: ignore[arg-type]
310+
311+ outputs = self ._filter_outputs_by_threshold (self .forward (inputs = batch )) # type: ignore[arg-type]
312+
313+ if isinstance (outputs , OTXBatchLossEntity ):
314+ raise TypeError (outputs )
315+
316+ return outputs
317+
318+ @property
319+ def best_confidence_threshold (self ) -> float :
320+ """Best confidence threshold to filter outputs.
321+
322+ Always returns the current value from hparams, with 0.5 as fallback.
323+ This ensures the threshold is always up-to-date after validation updates it.
324+ """
325+ threshold = self .hparams .get ("best_confidence_threshold" , None )
326+ if threshold is None :
327+ # Only log warning once to avoid spam
328+ if not getattr (self , "_threshold_warning_logged" , False ):
329+ log .warning ("There is no predefined best_confidence_threshold, 0.5 will be used as default." )
330+ self ._threshold_warning_logged = True
331+ return 0.5
332+ return float (threshold )
333+
267334 def on_load_checkpoint (self , ckpt : dict [str , Any ]) -> None :
268335 """Load state_dict from checkpoint.
269336
270- For detection , it is need to update confidence threshold information when
337+ For instance segmentation , it is needed to update confidence threshold and F1 score information when
271338 the metric is FMeasure.
272339 """
273- if best_confidence_threshold := ckpt .get ("confidence_threshold" , None ) or (
274- (hyper_parameters := ckpt .get ("hyper_parameters" , None ))
275- and (best_confidence_threshold := hyper_parameters .get ("best_confidence_threshold" , None ))
340+ hyper_parameters = ckpt .get ("hyper_parameters" , {})
341+
342+ # Load best confidence threshold (legacy and new format)
343+ if best_confidence_threshold := ckpt .get ("confidence_threshold" , None ) or hyper_parameters .get (
344+ "best_confidence_threshold" ,
345+ None ,
276346 ):
277347 self .hparams ["best_confidence_threshold" ] = best_confidence_threshold
278348 super ().on_load_checkpoint (ckpt )
279349
280350 def _log_metrics (self , meter : Metric , key : Literal ["val" , "test" ], ** compute_kwargs ) -> None :
281351 if key == "val" :
282- retval = super ()._log_metrics (meter , key )
352+ super ()._log_metrics (meter , key )
283353
284- # NOTE: Validation metric logging can update `best_confidence_threshold`
285- if (
286- isinstance (meter , MetricCollection )
287- and (fmeasure := getattr (meter , "FMeasure" , None ))
288- and (best_confidence_threshold := getattr (fmeasure , "best_confidence_threshold" , None ))
289- ) or (
290- isinstance (meter , FMeasure )
291- and (best_confidence_threshold := getattr (meter , "best_confidence_threshold" , None ))
292- ):
293- self .hparams ["best_confidence_threshold" ] = best_confidence_threshold
354+ # NOTE: Only update best_confidence_threshold when we achieve a NEW best F1 score
355+ fmeasure = None
356+ if isinstance (meter , MetricCollection ) and (fmeasure := getattr (meter , "FMeasure" , None )):
357+ pass # fmeasure is set
358+ elif isinstance (meter , FMeasure ):
359+ fmeasure = meter
294360
295- return retval
361+ if fmeasure is not None and hasattr (fmeasure , "best_confidence_threshold" ):
362+ self .hparams ["best_confidence_threshold" ] = fmeasure .best_confidence_threshold
296363
297364 if key == "test" :
298365 # NOTE: Test metric logging should use `best_confidence_threshold` found previously.
@@ -301,9 +368,37 @@ def _log_metrics(self, meter: Metric, key: Literal["val", "test"], **compute_kwa
301368 {"best_confidence_threshold" : best_confidence_threshold } if best_confidence_threshold else {}
302369 )
303370
304- return super ()._log_metrics (meter , key , ** compute_kwargs )
305-
306- raise ValueError (key )
371+ super ()._log_metrics (meter , key , ** compute_kwargs )
372+
    def _filter_outputs_by_threshold(self, outputs: OTXPredBatch) -> OTXPredBatch:
        """Keep only predictions whose score exceeds ``best_confidence_threshold``.

        Mutates ``outputs`` in place (and also returns it): the per-image
        scores, bboxes, labels, masks and polygons are replaced with filtered
        versions.

        :param outputs: Batched predictions to filter.
        :return: The same ``outputs`` object with its prediction fields filtered.
        """
        scores = []
        bboxes = []
        labels = []
        masks = []
        polygons = []

        # One iteration per image in the batch.
        for i in range(len(outputs.imgs_info)):  # type: ignore[arg-type]
            _scores = outputs.scores[i] if outputs.scores is not None else None
            _bboxes = outputs.bboxes[i] if outputs.bboxes is not None else None
            _masks = outputs.masks[i] if outputs.masks is not None else None
            _polygons = outputs.polygons[i] if outputs.polygons is not None else None
            _labels = outputs.labels[i] if outputs.labels is not None else None

            # NOTE(review): assumes _scores, _bboxes and _labels are non-None
            # index-able tensors here — torch.where would raise on None. Confirm
            # callers never pass a batch with missing scores/bboxes/labels.
            filtered_idx = torch.where(_scores > self.best_confidence_threshold)
            scores.append(_scores[filtered_idx])
            bboxes.append(_bboxes[filtered_idx])
            labels.append(_labels[filtered_idx])
            # NOTE(review): masks/polygons are appended only when present and
            # non-empty, so these lists may end up shorter than the batch when
            # some images lack them — verify downstream consumers tolerate this.
            if _masks is not None and len(_masks) > 0:
                masks.append(_masks[filtered_idx])
            if _polygons is not None and len(_polygons) > 0:
                polygons.append(_polygons[filtered_idx])

        outputs.scores = scores
        outputs.bboxes = bboxes
        outputs.labels = labels
        outputs.masks = masks
        outputs.polygons = polygons
        return outputs
307402
308403 def _convert_pred_entity_to_compute_metric (
309404 self ,
0 commit comments