
Commit 8286197

Fix loading/saving checkpoints in OTX (#4433)
* Minor fix to torch.load() (#4392)
* fix weights_only
* change otx version
* change Changelog
* fix linter
* fix checkpoint loading
* fix resume
* fix export/testing
* try to import error from pickle separately
* address bandit issue
* add test to check backward compatibility
* continue adding tests
* add tests
* add snapshots
* fix unit tests
* try full path
* try absolute
* updated snapshots
* bare except
* quick test
* quick test to see an error
* finally fix the error
* fix integration test
* fix integration tests for SS, Anomaly and CLS
* fix instance seg
1 parent f6f81db commit 8286197

File tree: 16 files changed (+227 −123 lines)
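
The common thread in the diffs below: OTX checkpoints no longer pickle dataclass instances such as LabelInfo and TileConfig directly; they are serialized to plain dicts under checkpoint["hyper_parameters"] on save and rebuilt from those dicts on load, which keeps checkpoints loadable under torch.load's stricter weights_only unpickling. A minimal sketch of that round-trip, using a hypothetical TileConfigStub in place of OTX's TileConfig:

from dataclasses import asdict, dataclass

@dataclass
class TileConfigStub:  # hypothetical stand-in for otx's TileConfig
    enable_tiler: bool = False
    tile_size: tuple = (400, 400)

checkpoint = {"hyper_parameters": {}}
# Save path: store a plain dict rather than the dataclass instance.
checkpoint["hyper_parameters"]["tile_config"] = asdict(TileConfigStub())
# Load path: rebuild the dataclass from the stored dict.
restored = TileConfigStub(**checkpoint["hyper_parameters"]["tile_config"])
print(restored.enable_tiler)  # False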

src/otx/core/metrics/fmeasure.py

Lines changed: 4 additions & 0 deletions
@@ -761,6 +761,10 @@ def f_measure_per_confidence(self) -> None | dict:
     @property
     def best_confidence_threshold(self) -> float:
         """Returns best confidence threshold as ScoreMetric if exists."""
+        if isinstance(self._best_confidence_threshold, np.float_):
+            # Convert numpy float to python float
+            self._best_confidence_threshold = self._best_confidence_threshold.item()
+
         if self._best_confidence_threshold is None:
             msg = (
                 "Cannot obtain best_confidence_threshold updated previously. "

src/otx/core/model/anomaly.py

Lines changed: 25 additions & 2 deletions
@@ -4,6 +4,7 @@
 
 from __future__ import annotations
 
+from dataclasses import asdict
 from typing import TYPE_CHECKING, Any, Sequence, TypeAlias
 
 import torch
@@ -344,6 +345,10 @@ def get_dummy_input(self, batch_size: int = 1) -> AnomalyModelInputs:
         msg = "Wrong anomaly task type"
         raise RuntimeError(msg)
 
+    @staticmethod
+    def _dispatch_label_info(*args, **kwargs) -> AnomalyLabelInfo:  # noqa: ARG004
+        return AnomalyLabelInfo()
+
 
 class AnomalyMixin:
     """Mixin inherited before AnomalibModule to override OTXModel methods."""
@@ -437,9 +442,9 @@ def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
         # calls Anomalib's on_save_checkpoint
         super().on_save_checkpoint(checkpoint)  # type: ignore[misc]
 
-        checkpoint["label_info"] = self.label_info  # type: ignore[attr-defined]
+        checkpoint["hyper_parameters"]["label_info"] = asdict(self.label_info)  # type: ignore[attr-defined]
         checkpoint["otx_version"] = __version__
-        checkpoint["tile_config"] = self.tile_config  # type: ignore[attr-defined]
+        checkpoint["hyper_parameters"]["tile_config"] = asdict(self.tile_config)  # type: ignore[attr-defined]
 
         attrs = ["_input_shape", "image_threshold", "pixel_threshold"]
         checkpoint["anomaly"] = {key: getattr(self, key, None) for key in attrs}
@@ -451,3 +456,21 @@ def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
         if anomaly_attrs := checkpoint.get("anomaly"):
             for key, value in anomaly_attrs.items():
                 setattr(self, key, value)
+
+    def load_state_dict(self, ckpt: dict[str, Any], *args, **kwargs) -> None:
+        """Load state dictionary from checkpoint state dictionary.
+
+        If checkpoint's label_info and OTXLitModule's label_info are different,
+        load_state_pre_hook for smart weight loading will be registered.
+        """
+        self.on_load_checkpoint(ckpt)  # type: ignore[misc]
+        state_dict = ckpt["state_dict"]
+        if "label_info" in ckpt.get("hyper_parameters", {}):
+            label_info = self._dispatch_label_info(ckpt["hyper_parameters"]["label_info"])  # type: ignore[attr-defined]
+            if label_info != self.label_info:  # type: ignore[attr-defined]
+                msg = (
+                    f"Checkpoint's label_info {label_info} "
+                    f"does not match with OTXLitModule's label_info {self.label_info}"  # type: ignore[attr-defined]
+                )
+                raise ValueError(msg)
+        return super().load_state_dict(state_dict, *args, **kwargs)  # type: ignore[misc]
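
The new load_state_dict above refuses to load weights whose stored label_info disagrees with the module's. A hedged sketch of that guard in isolation, using a hypothetical minimal LabelInfoStub rather than OTX's real label-info classes:

from dataclasses import dataclass

@dataclass
class LabelInfoStub:  # hypothetical; dataclass equality compares fields
    label_names: list

def check_label_info(module_info: LabelInfoStub, ckpt: dict) -> None:
    stored = ckpt.get("hyper_parameters", {}).get("label_info")
    if stored is not None and LabelInfoStub(**stored) != module_info:
        msg = f"Checkpoint's label_info {stored} does not match {module_info}"
        raise ValueError(msg)

# Matching info passes silently; a mismatch raises ValueError.
check_label_info(
    LabelInfoStub(["anomaly"]),
    {"hyper_parameters": {"label_info": {"label_names": ["anomaly"]}}},
)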

src/otx/core/model/base.py

Lines changed: 42 additions & 29 deletions
@@ -12,6 +12,7 @@
 import logging
 import warnings
 from abc import abstractmethod
+from dataclasses import asdict
 from typing import TYPE_CHECKING, Any, Callable, Literal, Sequence
 
 import numpy as np
@@ -106,13 +107,13 @@ class OTXModel(LightningModule):
 
     def __init__(
         self,
-        label_info: LabelInfoTypes,
+        label_info: LabelInfoTypes | dict,
         input_size: tuple[int, int] | None = None,
         optimizer: OptimizerCallable = DefaultOptimizerCallable,
         scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
         metric: MetricCallable = NullMetricCallable,
         torch_compile: bool = False,
-        tile_config: TileConfig = TileConfig(enable_tiler=False),
+        tile_config: TileConfig | dict = TileConfig(enable_tiler=False),
     ) -> None:
         super().__init__()
 
@@ -121,21 +122,21 @@ def __init__(
         self.input_size = input_size
         self.classification_layers: dict[str, dict[str, Any]] = {}
         self.model = self._create_model()
-        self.optimizer_callable = ensure_callable(optimizer)
-        self.scheduler_callable = ensure_callable(scheduler)
-        self.metric_callable = ensure_callable(metric)
+        self.optimizer_callable: OptimizerCallable = ensure_callable(optimizer)
+        self.scheduler_callable: LRSchedulerCallable = ensure_callable(scheduler)
+        self.metric_callable: MetricCallable = ensure_callable(metric)
 
         self.torch_compile = torch_compile
         self._explain_mode = False
 
         # NOTE: To guarantee immutability of the default value
+        if isinstance(tile_config, dict):
+            tile_config = TileConfig(**tile_config)
         self._tile_config = tile_config.clone()
-
-        # this line allows to access init params with 'self.hparams' attribute
-        # also ensures init params will be stored in ckpt
-        # TODO(vinnamki): Ticket no. 138995: MetricCallable should be saved in the checkpoint
-        # so that it can retrieve it from the checkpoint
-        self.save_hyperparameters(logger=False, ignore=["optimizer", "scheduler", "metric"])
+        self.save_hyperparameters(
+            logger=False,
+            ignore=["optimizer", "scheduler", "metric", "label_info", "tile_config"],
+        )
 
     def training_step(self, batch: T_OTXBatchDataEntity, batch_idx: int) -> Tensor | None:
         """Step for model training."""
@@ -376,38 +377,42 @@ def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
             compiled_state_dict = checkpoint["state_dict"]
             checkpoint["state_dict"] = remove_state_dict_prefix(compiled_state_dict, "_orig_mod.")
         super().on_save_checkpoint(checkpoint)
-
-        checkpoint["label_info"] = self.label_info
+        checkpoint["hyper_parameters"]["label_info"] = asdict(self.label_info)
         checkpoint["otx_version"] = __version__
-        checkpoint["tile_config"] = self.tile_config
+        checkpoint["hyper_parameters"]["tile_config"] = asdict(self.tile_config)
+        checkpoint.pop("datamodule_hparams_name", None)
+        checkpoint.pop(
+            "datamodule_hyper_parameters",
+            None,
+        )  # Remove datamodule_hyper_parameters to prevent storing OTX classes
 
     def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
         """Callback on loading checkpoint."""
         super().on_load_checkpoint(checkpoint)
-
-        if ckpt_label_info := checkpoint.get("label_info"):
-            if isinstance(ckpt_label_info, LabelInfo) and not hasattr(ckpt_label_info, "label_ids"):
-                # NOTE: This is for backward compatibility
-                ckpt_label_info = LabelInfo(
-                    label_groups=ckpt_label_info.label_groups,
-                    label_names=ckpt_label_info.label_names,
-                    label_ids=ckpt_label_info.label_names,
-                )
-            self._label_info = ckpt_label_info
-
-        if ckpt_tile_config := checkpoint.get("tile_config"):
-            self.tile_config = ckpt_tile_config
+        hyper_parameters = checkpoint.get("hyper_parameters", None)
+        if hyper_parameters:
+            if ckpt_label_info := hyper_parameters.get("label_info"):
+                self._label_info = self._dispatch_label_info(ckpt_label_info)
+            if ckpt_tile_config := hyper_parameters.get("tile_config"):
+                if isinstance(ckpt_tile_config, dict):
+                    ckpt_tile_config = TileConfig(**ckpt_tile_config)
+                self.tile_config = ckpt_tile_config
 
     def load_state_dict_incrementally(self, ckpt: dict[str, Any], *args, **kwargs) -> None:
         """Load state dict incrementally."""
         ckpt_label_info: LabelInfo | None = (
-            ckpt.get("label_info") if not is_ckpt_from_otx_v1(ckpt) else self.get_ckpt_label_info_v1(ckpt)
+            ckpt.get("hyper_parameters", {}).get("label_info")
+            if not is_ckpt_from_otx_v1(ckpt)
+            else self.get_ckpt_label_info_v1(ckpt)
        )
 
         if ckpt_label_info is None:
             msg = "Checkpoint should have `label_info`."
             raise ValueError(msg, ckpt_label_info)
 
+        if isinstance(ckpt_label_info, dict):
+            ckpt_label_info = LabelInfo(**ckpt_label_info)
+
         if not hasattr(ckpt_label_info, "label_ids"):
             msg = "Loading checkpoint from OTX < 2.2.1, label_ids are assigned automatically"
             logger.info(msg)
@@ -447,10 +452,10 @@ def load_state_dict(self, ckpt: dict[str, Any], *args, **kwargs) -> None:
             warnings.warn(msg, stacklevel=2)
             state_dict = self.load_from_otx_v1_ckpt(ckpt)
         elif is_ckpt_for_finetuning(ckpt):
+            self.on_load_checkpoint(ckpt)
             state_dict = ckpt["state_dict"]
         else:
             state_dict = ckpt
-
         return super().load_state_dict(state_dict, *args, **kwargs)
 
     def load_from_otx_v1_ckpt(self, ckpt: dict[str, Any]) -> dict:
@@ -828,6 +833,11 @@ def get_dummy_input(self, batch_size: int = 1) -> OTXBatchDataEntity:
 
     @staticmethod
     def _dispatch_label_info(label_info: LabelInfoTypes) -> LabelInfo:
+        if isinstance(label_info, dict):
+            if "label_ids" not in label_info:
+                # NOTE: This is for backward compatibility
+                label_info["label_ids"] = label_info["label_names"]
+            return LabelInfo(**label_info)
         if isinstance(label_info, int):
             return LabelInfo.from_num_classes(num_classes=label_info)
         if isinstance(label_info, Sequence) and all(isinstance(name, str) for name in label_info):
@@ -837,6 +847,9 @@ def _dispatch_label_info(label_info: LabelInfoTypes) -> LabelInfo:
                 label_ids=[str(i) for i in range(len(label_info))],
             )
         if isinstance(label_info, LabelInfo):
+            if not hasattr(label_info, "label_ids"):
+                # NOTE: This is for backward compatibility
+                label_info.label_ids = label_info.label_names
             return label_info
 
         raise TypeError(label_info)
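
The dict branch added to _dispatch_label_info is the load-side half of the asdict serialization: checkpoints written before label_ids existed (OTX < 2.2.1) get their ids backfilled from label_names. A sketch of that backward-compatibility path with a hypothetical minimal dataclass:

from dataclasses import dataclass, field

@dataclass
class LabelInfoStub:  # hypothetical stand-in for otx's LabelInfo
    label_names: list
    label_groups: list
    label_ids: list = field(default_factory=list)

def dispatch(label_info):
    if isinstance(label_info, dict):
        if "label_ids" not in label_info:
            # Older checkpoints lack label_ids; reuse label_names.
            label_info["label_ids"] = label_info["label_names"]
        return LabelInfoStub(**label_info)
    return label_info

old_entry = {"label_names": ["cat", "dog"], "label_groups": [["cat", "dog"]]}
print(dispatch(old_entry).label_ids)  # ['cat', 'dog']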

src/otx/core/model/classification.py

Lines changed: 12 additions & 4 deletions
@@ -478,10 +478,18 @@ def _convert_pred_entity_to_compute_metric(
 
     @staticmethod
     def _dispatch_label_info(label_info: LabelInfoTypes) -> LabelInfo:
-        if not isinstance(label_info, HLabelInfo):
-            raise TypeError(label_info)
-
-        return label_info
+        if isinstance(label_info, dict):
+            if "label_ids" not in label_info:
+                # NOTE: This is for backward compatibility
+                label_info["label_ids"] = label_info["label_names"]
+            return HLabelInfo(**label_info)
+        if isinstance(label_info, HLabelInfo):
+            if not hasattr(label_info, "label_ids"):
+                # NOTE: This is for backward compatibility
+                label_info.label_ids = label_info.label_names
+            return label_info
+
+        raise TypeError(label_info)
 
     def get_dummy_input(self, batch_size: int = 1) -> HlabelClsBatchDataEntity:
         """Returns a dummy input for classification OV model."""

src/otx/core/model/instance_segmentation.py

Lines changed: 1 addition & 1 deletion
@@ -283,6 +283,7 @@ def _export_parameters(self) -> TaskLevelExportParameters:
         # Instance segmentation needs to add empty label to satisfy MAPI wrapper requirements
         modified_label_info.label_names.insert(0, "otx_empty_lbl")
         modified_label_info.label_ids.insert(0, "None")
+        modified_label_info.label_groups[0].insert(0, "otx_empty_lbl")
 
         return super()._export_parameters.wrap(
             model_type="MaskRCNN",
@@ -762,7 +763,6 @@ def _log_metrics(self, meter: Metric, key: Literal["val", "test"], **compute_kwargs):
 
     def _create_label_info_from_ov_ir(self) -> LabelInfo:
         ov_model = self.model.get_model()
-
         if ov_model.has_rt_info(["model_info", "label_info"]):
             serialized = ov_model.get_rt_info(["model_info", "label_info"]).value
             ir_label_info = LabelInfo.from_json(serialized)
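
The one-line export fix above keeps the label structures index-aligned: once the synthetic "otx_empty_lbl" entry is prepended to label_names and label_ids for the ModelAPI wrapper, the first label group has to receive it too. A toy illustration with hypothetical label data:

label_names = ["person", "car"]
label_ids = ["0", "1"]
label_groups = [["person", "car"]]

label_names.insert(0, "otx_empty_lbl")
label_ids.insert(0, "None")
label_groups[0].insert(0, "otx_empty_lbl")  # the line this commit adds

assert len(label_names) == len(label_ids) == len(label_groups[0])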

src/otx/core/model/segmentation.py

Lines changed: 8 additions & 0 deletions
@@ -193,6 +193,11 @@ def _convert_pred_entity_to_compute_metric(
 
     @staticmethod
     def _dispatch_label_info(label_info: LabelInfoTypes) -> LabelInfo:
+        if isinstance(label_info, dict):
+            if "label_ids" not in label_info:
+                # NOTE: This is for backward compatibility
+                label_info["label_ids"] = label_info["label_names"]
+            return SegLabelInfo(**label_info)
         if isinstance(label_info, int):
             return SegLabelInfo.from_num_classes(num_classes=label_info)
         if isinstance(label_info, Sequence) and all(isinstance(name, str) for name in label_info):
@@ -202,6 +207,9 @@ def _dispatch_label_info(label_info: LabelInfoTypes) -> LabelInfo:
                 label_ids=[str(i) for i in range(len(label_info))],
             )
         if isinstance(label_info, SegLabelInfo):
+            if not hasattr(label_info, "label_ids"):
+                # NOTE: This is for backward compatibility
+                label_info.label_ids = label_info.label_names
             return label_info
 
         raise TypeError(label_info)

src/otx/core/utils/utils.py

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@ def is_ckpt_from_otx_v1(ckpt: dict) -> bool:
     Returns:
         bool: True means the checkpoint comes from otx1
     """
-    return "model" in ckpt and ckpt["VERSION"] == 1
+    return "model" in ckpt and "VERSION" in ckpt and ckpt["VERSION"] == 1
 
 
 def is_ckpt_for_finetuning(ckpt: dict) -> bool:
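
The extra membership test matters because a checkpoint can contain a "model" key without any "VERSION" key, and Python's `and` short-circuits left to right, so the subscript is never evaluated for such dicts. A quick sketch with hypothetical checkpoint dicts:

v1_ckpt = {"model": {}, "VERSION": 1}
v2_ckpt = {"model": {}, "state_dict": {}}  # hypothetical newer-style dict: no "VERSION"

def is_ckpt_from_otx_v1(ckpt: dict) -> bool:
    # The old form, ckpt["VERSION"], raised KeyError on v2_ckpt.
    return "model" in ckpt and "VERSION" in ckpt and ckpt["VERSION"] == 1

print(is_ckpt_from_otx_v1(v1_ckpt))  # True
print(is_ckpt_from_otx_v1(v2_ckpt))  # False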
