Commit 4fb7e61

Fix confidence threshold cache invalidation and filtering logic (#4498)
* Refactor confidence threshold handling in detection and instance segmentation models
* Add stage parameter to model methods for validation and testing
* Refactor metric computation in OTX models by removing stage parameter and consolidating test step logic
* Fix inst-seg _filter_outputs_by_threshold
* Remove best_confidence_threshold_list from checkpoint during save and add unit tests for detection model confidence threshold logic
* Enhance unit tests for detection threshold logic to ensure compatibility with Python 3.10
* Fix tests and formatting; update unit tests
* Remove best_confidence_threshold_list and update related unit tests for checkpoint functionality
* Refactor checkpoint saving in OTXModel to remove an unnecessary line and update comments in OTXDetectionModel for clarity on best_confidence_threshold usage
1 parent 8eb3ae9 commit 4fb7e61
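The squash message mentions a save-side change (dropping `best_confidence_threshold_list` when writing checkpoints) whose diff is not among the files shown below. A minimal sketch of that idea, assuming the list lives under the checkpoint's `hyper_parameters` dict; `OTXModelSketch` and the hook body are illustrative, not the actual OTXModel code:

```python
from typing import Any


class OTXModelSketch:
    """Hypothetical stand-in for OTXModel's checkpoint-saving hook."""

    def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        # Keep only the scalar best_confidence_threshold; drop the legacy
        # per-epoch list so reloading cannot resurrect a stale cached value.
        hparams = checkpoint.get("hyper_parameters", {})
        hparams.pop("best_confidence_threshold_list", None)
```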

File tree: 4 files changed (+638, -56 lines)

src/otx/backend/native/models/detection/base.py

Lines changed: 61 additions & 36 deletions
```diff
@@ -82,23 +82,33 @@ def __init__(
         self.model.feature_vector_fn = feature_vector_fn
         self.model.explain_fn = self.get_explain_fn()
 
-    def validation_step(self, batch: OTXDataBatch, batch_idx: int) -> OTXPredBatch:
-        """Perform a single validation step on a batch of data from the validation set.
-
-        :param batch: A batch of data (a tuple) containing the input tensor of images and target
-            labels.
-        :param batch_idx: The index of the current batch.
-        """
-        return self._filter_outputs_by_threshold(super().validation_step(batch, batch_idx))
-
     def test_step(self, batch: OTXDataBatch, batch_idx: int) -> OTXPredBatch:
         """Perform a single test step on a batch of data from the test set.
 
         :param batch: A batch of data (a tuple) containing the input tensor of images and target
             labels.
         :param batch_idx: The index of the current batch.
         """
-        return self._filter_outputs_by_threshold(super().test_step(batch, batch_idx))
+        preds = self.forward(inputs=batch)
+
+        if isinstance(preds, OTXBatchLossEntity):
+            raise TypeError(preds)
+
+        # 1. Filter outputs by threshold
+        preds = self._filter_outputs_by_threshold(preds)
+        metric_inputs = self._convert_pred_entity_to_compute_metric(preds, batch)
+
+        # 2. Update metric
+        if isinstance(metric_inputs, dict):
+            self.metric.update(**metric_inputs)
+            return preds
+
+        if isinstance(metric_inputs, list) and all(isinstance(inp, dict) for inp in metric_inputs):
+            for inp in metric_inputs:
+                self.metric.update(**inp)
+            return preds
+
+        raise TypeError(metric_inputs)
 
     def predict_step(
         self,
@@ -118,6 +128,10 @@ def predict_step(
         return outputs
 
     def _filter_outputs_by_threshold(self, outputs: OTXPredBatch) -> OTXPredBatch:
+        # NOTE: best_confidence_threshold comes from:
+        # 1. During validation: FMeasure metric computes optimal threshold, stored in hparams via _log_metrics
+        # 2. During test/predict: Uses the threshold computed during validation (from hparams)
+        # 3. If no threshold available: defaults to 0.5
         scores = []
         bboxes = []
         labels = []
@@ -316,53 +330,64 @@ def _convert_pred_entity_to_compute_metric(
     def on_load_checkpoint(self, ckpt: dict[str, Any]) -> None:
         """Load state_dict from checkpoint.
 
-        For detection, it is need to update confidence threshold information when
+        For detection, it is needed to update confidence threshold and F1 score information when
        the metric is FMeasure.
        """
-        if best_confidence_threshold := ckpt.get("confidence_threshold", None) or (
-            (hyper_parameters := ckpt.get("hyper_parameters", None))
-            and (best_confidence_threshold := hyper_parameters.get("best_confidence_threshold", None))
+        hyper_parameters = ckpt.get("hyper_parameters", {})
+
+        # Load best confidence threshold (legacy and new format)
+        if best_confidence_threshold := ckpt.get("confidence_threshold", None) or hyper_parameters.get(
+            "best_confidence_threshold",
+            None,
         ):
             self.hparams["best_confidence_threshold"] = best_confidence_threshold
         super().on_load_checkpoint(ckpt)
 
     def _log_metrics(self, meter: Metric, key: Literal["val", "test"], **compute_kwargs) -> None:
+        """This function is called every epoch.
+
+        Args:
+            meter: Metric object
+            key: "val" or "test"
+            compute_kwargs: Additional keyword arguments for the metric computation
+
+        """
         if key == "val":
-            retval = super()._log_metrics(meter, key)
+            super()._log_metrics(meter, key)
 
-            # NOTE: Validation metric logging can update `best_confidence_threshold`
-            if (
-                isinstance(meter, MetricCollection)
-                and (fmeasure := getattr(meter, "FMeasure", None))
-                and (best_confidence_threshold := getattr(fmeasure, "best_confidence_threshold", None))
-            ) or (
-                isinstance(meter, FMeasure)
-                and (best_confidence_threshold := getattr(meter, "best_confidence_threshold", None))
-            ):
-                self.hparams["best_confidence_threshold"] = best_confidence_threshold
+            fmeasure = None
+            if isinstance(meter, MetricCollection) and (fmeasure := getattr(meter, "FMeasure", None)):
+                pass  # fmeasure is set
+            elif isinstance(meter, FMeasure):
+                fmeasure = meter
 
-            return retval
+            if fmeasure is not None and hasattr(fmeasure, "best_confidence_threshold"):
+                self.hparams["best_confidence_threshold"] = fmeasure.best_confidence_threshold
 
         if key == "test":
-            # NOTE: Test metric logging should use `best_confidence_threshold` found previously.
+            # NOTE: Test metric logging should use `best_confidence_threshold` in the loaded checkpoint.
             best_confidence_threshold = self.hparams.get("best_confidence_threshold", None)
             compute_kwargs = (
                 {"best_confidence_threshold": best_confidence_threshold} if best_confidence_threshold else {}
             )
 
-            return super()._log_metrics(meter, key, **compute_kwargs)
-
-        raise ValueError(key)
+            super()._log_metrics(meter, key, **compute_kwargs)
 
     @property
     def best_confidence_threshold(self) -> float:
-        """Best confidence threshold to filter outputs."""
-        if not hasattr(self, "_best_confidence_threshold"):
-            self._best_confidence_threshold = self.hparams.get("best_confidence_threshold", None)
-        if self._best_confidence_threshold is None:
+        """Best confidence threshold to filter outputs.
+
+        Always returns the current value from hparams, with 0.5 as fallback.
+        This ensures the threshold is always up-to-date after validation updates it.
+        """
+        threshold = self.hparams.get("best_confidence_threshold", None)
+        if threshold is None:
+            # Only log warning once to avoid spam
+            if not getattr(self, "_threshold_warning_logged", False):
                log.warning("There is no predefined best_confidence_threshold, 0.5 will be used as default.")
-            self._best_confidence_threshold = 0.5
-        return self._best_confidence_threshold
+                self._threshold_warning_logged = True
+            return 0.5
+        return float(threshold)
 
     def get_dummy_input(self, batch_size: int = 1) -> OTXDataBatch:  # type: ignore[override]
         """Returns a dummy input for detection model."""
```

src/otx/backend/native/models/instance_segmentation/base.py

Lines changed: 114 additions & 19 deletions
```diff
@@ -8,6 +8,7 @@
 from __future__ import annotations
 
 import copy
+import logging as log
 import types
 from contextlib import contextmanager
 from typing import TYPE_CHECKING, Any, Callable, Iterator, Literal, Sequence
@@ -49,6 +50,11 @@
 class OTXInstanceSegModel(OTXModel):
     """Base class for the Instance Segmentation models used in OTX.
 
+    NOTE: OTXInstanceSegModel has many duplicate methods to OTXDetectionModel,
+    however, it is not a subclass of OTXDetectionModel because it has different
+    export parameters and different metric computation. Some refactor could be done
+    to reduce the code duplication in the future.
+
     Args:
         label_info (LabelInfoTypes | int | Sequence): Information about the labels used in the model.
             If `int` is given, label info will be constructed from number of classes,
@@ -264,35 +270,96 @@ def _export_parameters(self) -> TaskLevelExportParameters:
             label_info=modified_label_info,
         )
 
+    def test_step(self, batch: OTXDataBatch, batch_idx: int) -> OTXPredBatch:
+        """Perform a single test step on a batch of data from the test set.
+
+        :param batch: A batch of data (a tuple) containing the input tensor of images and target
+            labels.
+        :param batch_idx: The index of the current batch.
+        """
+        preds = self.forward(inputs=batch)
+
+        if isinstance(preds, OTXBatchLossEntity):
+            raise TypeError(preds)
+
+        # 1. Filter outputs by threshold
+        preds = self._filter_outputs_by_threshold(preds)
+        metric_inputs = self._convert_pred_entity_to_compute_metric(preds, batch)
+
+        # 2. Update metric
+        if isinstance(metric_inputs, dict):
+            self.metric.update(**metric_inputs)
+            return preds
+
+        if isinstance(metric_inputs, list) and all(isinstance(inp, dict) for inp in metric_inputs):
+            for inp in metric_inputs:
+                self.metric.update(**inp)
+            return preds
+
+        raise TypeError(metric_inputs)
+
+    def predict_step(
+        self,
+        batch: OTXDataBatch | OTXTileBatchDataEntity,
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> OTXPredBatch:
+        """Step function called during PyTorch Lightning Trainer's predict."""
+        if self.explain_mode:
+            return self._filter_outputs_by_threshold(self.forward_explain(inputs=batch))  # type: ignore[arg-type]
+
+        outputs = self._filter_outputs_by_threshold(self.forward(inputs=batch))  # type: ignore[arg-type]
+
+        if isinstance(outputs, OTXBatchLossEntity):
+            raise TypeError(outputs)
+
+        return outputs
+
+    @property
+    def best_confidence_threshold(self) -> float:
+        """Best confidence threshold to filter outputs.
+
+        Always returns the current value from hparams, with 0.5 as fallback.
+        This ensures the threshold is always up-to-date after validation updates it.
+        """
+        threshold = self.hparams.get("best_confidence_threshold", None)
+        if threshold is None:
+            # Only log warning once to avoid spam
+            if not getattr(self, "_threshold_warning_logged", False):
+                log.warning("There is no predefined best_confidence_threshold, 0.5 will be used as default.")
+                self._threshold_warning_logged = True
+            return 0.5
+        return float(threshold)
+
     def on_load_checkpoint(self, ckpt: dict[str, Any]) -> None:
         """Load state_dict from checkpoint.
 
-        For detection, it is need to update confidence threshold information when
+        For instance segmentation, it is needed to update confidence threshold and F1 score information when
        the metric is FMeasure.
        """
-        if best_confidence_threshold := ckpt.get("confidence_threshold", None) or (
-            (hyper_parameters := ckpt.get("hyper_parameters", None))
-            and (best_confidence_threshold := hyper_parameters.get("best_confidence_threshold", None))
+        hyper_parameters = ckpt.get("hyper_parameters", {})
+
+        # Load best confidence threshold (legacy and new format)
+        if best_confidence_threshold := ckpt.get("confidence_threshold", None) or hyper_parameters.get(
+            "best_confidence_threshold",
+            None,
        ):
             self.hparams["best_confidence_threshold"] = best_confidence_threshold
         super().on_load_checkpoint(ckpt)
 
     def _log_metrics(self, meter: Metric, key: Literal["val", "test"], **compute_kwargs) -> None:
         if key == "val":
-            retval = super()._log_metrics(meter, key)
+            super()._log_metrics(meter, key)
 
-            # NOTE: Validation metric logging can update `best_confidence_threshold`
-            if (
-                isinstance(meter, MetricCollection)
-                and (fmeasure := getattr(meter, "FMeasure", None))
-                and (best_confidence_threshold := getattr(fmeasure, "best_confidence_threshold", None))
-            ) or (
-                isinstance(meter, FMeasure)
-                and (best_confidence_threshold := getattr(meter, "best_confidence_threshold", None))
-            ):
-                self.hparams["best_confidence_threshold"] = best_confidence_threshold
+            # NOTE: Only update best_confidence_threshold when we achieve a NEW best F1 score
+            fmeasure = None
+            if isinstance(meter, MetricCollection) and (fmeasure := getattr(meter, "FMeasure", None)):
+                pass  # fmeasure is set
+            elif isinstance(meter, FMeasure):
+                fmeasure = meter
 
-            return retval
+            if fmeasure is not None and hasattr(fmeasure, "best_confidence_threshold"):
+                self.hparams["best_confidence_threshold"] = fmeasure.best_confidence_threshold
 
         if key == "test":
             # NOTE: Test metric logging should use `best_confidence_threshold` found previously.
@@ -301,9 +368,37 @@ def _log_metrics(self, meter: Metric, key: Literal["val", "test"], **compute_kwargs) -> None:
                 {"best_confidence_threshold": best_confidence_threshold} if best_confidence_threshold else {}
             )
 
-            return super()._log_metrics(meter, key, **compute_kwargs)
-
-        raise ValueError(key)
+            super()._log_metrics(meter, key, **compute_kwargs)
+
+    def _filter_outputs_by_threshold(self, outputs: OTXPredBatch) -> OTXPredBatch:
+        scores = []
+        bboxes = []
+        labels = []
+        masks = []
+        polygons = []
+
+        for i in range(len(outputs.imgs_info)):  # type: ignore[arg-type]
+            _scores = outputs.scores[i] if outputs.scores is not None else None
+            _bboxes = outputs.bboxes[i] if outputs.bboxes is not None else None
+            _masks = outputs.masks[i] if outputs.masks is not None else None
+            _polygons = outputs.polygons[i] if outputs.polygons is not None else None
+            _labels = outputs.labels[i] if outputs.labels is not None else None
+
+            filtered_idx = torch.where(_scores > self.best_confidence_threshold)
+            scores.append(_scores[filtered_idx])
+            bboxes.append(_bboxes[filtered_idx])
+            labels.append(_labels[filtered_idx])
+            if _masks is not None and len(_masks) > 0:
+                masks.append(_masks[filtered_idx])
+            if _polygons is not None and len(_polygons) > 0:
+                polygons.append(_polygons[filtered_idx])
+
+        outputs.scores = scores
+        outputs.bboxes = bboxes
+        outputs.labels = labels
+        outputs.masks = masks
+        outputs.polygons = polygons
+        return outputs
 
     def _convert_pred_entity_to_compute_metric(
         self,
```
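For reference, the per-image filtering added in `_filter_outputs_by_threshold` is plain tensor indexing; a standalone PyTorch sketch with made-up scores and an assumed 0.5 threshold:

```python
import torch

best_confidence_threshold = 0.5  # assumed for the example
scores = torch.tensor([0.9, 0.3, 0.7])
bboxes = torch.tensor([[0.0, 0.0, 10.0, 10.0], [5.0, 5.0, 8.0, 8.0], [2.0, 2.0, 6.0, 6.0]])
labels = torch.tensor([1, 0, 2])

# torch.where with a single argument returns the indices of True entries.
keep = torch.where(scores > best_confidence_threshold)
print(scores[keep])  # tensor([0.9000, 0.7000])
print(bboxes[keep])  # the two boxes whose scores passed the threshold
print(labels[keep])  # tensor([1, 2])
```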

src/otx/backend/native/utils/utils.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -85,7 +85,7 @@ def mock_modules_for_chkpt() -> Iterator[None]:
     sys.modules["otx.core.types"] = otx.types
     sys.modules["otx.core.types.task"] = otx.types.task
     sys.modules["otx.core.types.label"] = otx.types.label
-    sys.modules["otx.core.model"] = otx.backend.native.models
+    sys.modules["otx.core.model"] = otx.backend.native.models  # type: ignore[attr-defined]
     sys.modules["otx.core.metrics"] = otx.metrics
 
     yield
```
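The aliasing in `mock_modules_for_chkpt` exists so checkpoints pickled against the old `otx.core.*` paths resolve against the relocated packages. A standalone sketch of the same pattern; `legacy_pkg` and `MARKER` are invented for illustration:

```python
import importlib
import sys
import types

# Register a module object under a legacy import path so anything that
# imports (or unpickles against) the old path resolves to the new code.
relocated = types.ModuleType("relocated")
relocated.MARKER = "resolved from the new location"

sys.modules["legacy_pkg"] = types.ModuleType("legacy_pkg")
sys.modules["legacy_pkg.model"] = relocated

mod = importlib.import_module("legacy_pkg.model")
print(mod.MARKER)  # -> "resolved from the new location"
```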
