
Commit ac266a2

Merge 2.4.6 to 2.5.0 (Fix checkpoint loading update) (#4475)
* merge changes
* fix linter
* fix readme
* update modules mock
* fix unit test
* fix tox
* create context manager
* add snapshot for anomaly
* add hlabel snapshot test
* minor fix
* fix changelog
* fix linter
1 parent 4011712 commit ac266a2

32 files changed: +501 -197 lines changed

.github/workflows/code_scan.yaml

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@ jobs:
           python-version: "3.10"
       - name: Freeze dependencies
         run: |
-          pip install '.[docs,base,mmlab,anomaly,transformers]'
+          pip install '.[docs]'
           pip freeze > requirements.txt
       - name: Run Trivy Scan (vuln)
         uses: aquasecurity/trivy-action@18f2510ee396bbf400402947b394f2dd8c87dbb0 # v0.29.0

.github/workflows/daily.yaml

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@ jobs:
           - task: "multi_class_cls"
           - task: "multi_label_cls"
           - task: "h_label_cls"
-          - task: "anomaly_classification"
+          - task: "anomaly"
           - task: "keypoint_detection"
           - task: "detection"
           - task: "instance_segmentation"

.github/workflows/perf_benchmark_v2.yaml

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@ jobs:
       fail-fast: false
       matrix:
         include:
-          - task: anomaly_classification
+          - task: anomaly
           - task: detection
           - task: multi_class_cls
           - task: multi_label_cls

.github/workflows/pre_merge.yaml

Lines changed: 1 addition & 1 deletion
@@ -91,7 +91,7 @@ jobs:
           - task: "multi_class_cls"
           - task: "multi_label_cls"
           - task: "h_label_cls"
-          - task: "anomaly_classification"
+          - task: "anomaly"
           - task: "keypoint_detection"
           - task: "detection"
           - task: "instance_segmentation"

.github/workflows/publish.yaml

Lines changed: 4 additions & 0 deletions
@@ -19,6 +19,10 @@ jobs:
         uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
         with:
           python-version: "3.10"
+      - name: Install required Python packages
+        run: |
+          python -m pip install --upgrade pip
+          python -m pip install build
       - name: Build sdist
         run: python -m build --sdist
       - uses: actions/upload-artifact@6f51ac03b9356f520e9adb1b1b7802705f340c2b # v4.5.0

CHANGELOG.md

Lines changed: 68 additions & 4 deletions
@@ -2,16 +2,36 @@
 
 All notable changes to this project will be documented in this file.
 
-## \[Unreleased\]
+## \[2.5.0\]
 
 ### Enhancements
 
+- Refactor OTXModels
+  (<https://github.com/open-edge-platform/training_extensions/pull/4241>)
+- Introduce Native OTX Engine, refactor folders structure
+  (<https://github.com/open-edge-platform/training_extensions/pull/4414>),
+  (<https://github.com/open-edge-platform/training_extensions/pull/4408>),
+  (<https://github.com/open-edge-platform/training_extensions/pull/4339>)
+- Introduce OVEngine
+  (<https://github.com/open-edge-platform/training_extensions/pull/4374>),
+  (<https://github.com/open-edge-platform/training_extensions/pull/4436>)
+- Refactor OTX DataEntities
+  (<https://github.com/open-edge-platform/training_extensions/pull/4322>),
+  (<https://github.com/open-edge-platform/training_extensions/pull/4282>),
+  (<https://github.com/open-edge-platform/training_extensions/pull/4239>),
+  (<https://github.com/open-edge-platform/training_extensions/pull/4412>)
+- Introduce new performance benchmark v2
+  (<https://github.com/open-edge-platform/training_extensions/pull/4435>,
+  <https://github.com/open-edge-platform/training_extensions/pull/4400>,
+  <https://github.com/open-edge-platform/training_extensions/pull/4435>)
+- Update documentation
+  (<https://github.com/open-edge-platform/training_extensions/pull/4447>)
 - Bump OV and NNCF to 2025.2
-  (https://github.com/open-edge-platform/training_extensions/pull/4423)
+  (<https://github.com/open-edge-platform/training_extensions/pull/4423>)
 - Bump torch to 2.7.0
-  (https://github.com/open-edge-platform/training_extensions/pull/4361)
+  (<https://github.com/open-edge-platform/training_extensions/pull/4361>)
 - Add model arch name to exported model metadata
-  (https://github.com/open-edge-platform/training_extensions/pull/4407)
+  (<https://github.com/open-edge-platform/training_extensions/pull/4407>)
 
 ### Bug fixes
 
@@ -23,12 +43,56 @@ All notable changes to this project will be documented in this file.
   (<https://github.com/openvinotoolkit/training_extensions/pull/4300>)
 - Fix missing mAP score reporting for instance segmentation
   (<https://github.com/open-edge-platform/training_extensions/pull/4364>)
+- Provide XPU workarounds for object detection task
+  (<https://github.com/open-edge-platform/training_extensions/pull/4464>)
 
 ### Removed
 
 - Remove Visual Prompting
   (<https://github.com/openvinotoolkit/training_extensions/pull/4291>,<https://github.com/open-edge-platform/training_extensions/pull/4370>)
 
+## \[2.4.6\]
+
+### Bug fixes
+
+- Fix label info dispatching
+  (<https://github.com/open-edge-platform/training_extensions/pull/4443>)
+
+## \[2.4.5\]
+
+### Bug fixes
+
+- Fix UFlow by adding self.\_setup in UFlow model
+  (<https://github.com/open-edge-platform/training_extensions/pull/4431>)
+- Fix loading/saving checkpoints in OTX
+  (<https://github.com/open-edge-platform/training_extensions/pull/4433>,
+  <https://github.com/open-edge-platform/training_extensions/pull/4438>)
+
+## \[2.4.4\]
+
+### Bug fixes
+
+- Fix torch.load() to be able to load all OTX custom snapshots
+  (<https://github.com/open-edge-platform/training_extensions/pull/4392>)
+
+## \[2.4.3\]
+
+### Enhancements
+
+- Bump torch to 2.7.0
+
+## \[2.4.2\]
+
+### Bug fixes
+
+- Fix torchmetrics to 1.6.0
+
+## \[2.4.1\]
+
+### Bug fixes
+
+- Update Datumaro from 1.10.0rc0 to 1.10.0
+
 ## \[2.4.0\]
 
 ### New features

src/otx/backend/native/engine.py

Lines changed: 39 additions & 35 deletions
@@ -13,6 +13,7 @@
 import time
 from contextlib import contextmanager
 from pathlib import Path
+from pickle import UnpicklingError  # nosec B403: UnpicklingError is used only for exception handling
 from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Iterator, Literal
 from warnings import warn
 
@@ -267,7 +268,7 @@ def train(
             # load the model state from the checkpoint incrementally.
             # This means only the model weights are loaded. If there is a mismatch in label_info,
             # perform incremental weight loading for the model's classification layer.
-            ckpt = torch.load(checkpoint)
+            ckpt = self._load_model_checkpoint(checkpoint, map_location="cpu")
             self.model.load_state_dict_incrementally(ckpt)
 
         with override_metric_callable(model=self.model, new_metric_callable=metric) as model:
@@ -342,10 +343,8 @@ def test(
         # NOTE, trainer.test takes only lightning based checkpoint.
         # So, it can't take the OTX1.x checkpoint.
         if checkpoint is not None:
-            kwargs_user_input: dict[str, Any] = {}
-
-            model_cls = model.__class__
-            model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint, **kwargs_user_input)
+            ckpt = self._load_model_checkpoint(checkpoint, map_location="cpu")
+            model.load_state_dict(ckpt)
 
         if model.label_info != self.datamodule.label_info:
             if (
@@ -432,10 +431,8 @@ def predict(
         datamodule = datamodule if datamodule is not None else self.datamodule
 
         if checkpoint is not None:
-            kwargs_user_input: dict[str, Any] = {}
-
-            model_cls = model.__class__
-            model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint, **kwargs_user_input)
+            ckpt = self._load_model_checkpoint(checkpoint, map_location="cpu")
+            model.load_state_dict(ckpt)
 
         if model.label_info != self.datamodule.label_info:
             msg = (
@@ -534,20 +531,8 @@ def export(
                 warn(msg, stacklevel=1)
                 export_demo_package = False
 
-            kwargs_user_input: dict[str, Any] = {}
-
-            model_cls = self.model.__class__
-            if hasattr(self.model, "model_name"):
-                # NOTE: This is a solution to fix backward compatibility issue.
-                # If the model has `model_name` attribute, it will be passed to the `load_from_checkpoint` method,
-                # making sure previous model trained without model_name can be loaded.
-                kwargs_user_input["model_name"] = self.model.model_name
-
-            self._model = model_cls.load_from_checkpoint(
-                checkpoint_path=checkpoint,
-                map_location="cpu",
-                **kwargs_user_input,
-            )
+            ckpt = self._load_model_checkpoint(checkpoint, map_location="cpu")
+            self.model.load_state_dict(ckpt)
             self.model.eval()
 
             self.model.explain_mode = explain
@@ -617,10 +602,8 @@ def explain(
         datamodule = datamodule if datamodule is not None else self.datamodule
 
         if checkpoint is not None:
-            kwargs_user_input: dict[str, Any] = {}
-
-            model_cls = model.__class__
-            model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint, **kwargs_user_input)
+            ckpt = self._load_model_checkpoint(checkpoint, map_location="cpu")
+            model.load_state_dict(ckpt)
 
         if model.label_info != self.datamodule.label_info:
             msg = (
@@ -706,14 +689,8 @@ def benchmark(
         checkpoint = checkpoint if checkpoint is not None else self.checkpoint
 
         if checkpoint is not None:
-            kwargs_user_input: dict[str, Any] = {}
-
-            model_cls = self.model.__class__
-            self._model = model_cls.load_from_checkpoint(
-                checkpoint_path=checkpoint,
-                map_location="cpu",
-                **kwargs_user_input,
-            )
+            ckpt = self._load_model_checkpoint(checkpoint, map_location="cpu")
+            self.model.load_state_dict(ckpt)
         self.model.eval()
 
         def dummy_infer(model: OTXModel, batch_size: int = 1) -> float:
@@ -1090,3 +1067,30 @@ def datamodule(self) -> OTXDataModule:
     def is_supported(model: MODEL, data: DATA) -> bool:
         """Check if the engine is supported for the given model and data."""
         return bool(isinstance(model, OTXModel) and isinstance(data, OTXDataModule))
+
+    @staticmethod
+    def _load_model_checkpoint(checkpoint: PathLike, map_location: str | None = None) -> dict[str, Any]:
+        """Load model checkpoint from the given path.
+
+        Args:
+            checkpoint (PathLike): Path to the checkpoint file.
+
+        Returns:
+            dict[str, Any]: The loaded state dictionary from the checkpoint.
+        """
+        if not Path(checkpoint).exists():
+            msg = f"Checkpoint file does not exist: {checkpoint}"
+            raise FileNotFoundError(msg)
+
+        try:
+            ckpt = torch.load(checkpoint, map_location=map_location)
+        except UnpicklingError:
+            from otx.backend.native.utils.utils import mock_modules_for_chkpt
+
+            with mock_modules_for_chkpt():
+                ckpt = torch.load(checkpoint, map_location=map_location, weights_only=False)
+        except Exception as e:
+            msg = f"Failed to load checkpoint from {checkpoint}. Please check the file."
+            raise RuntimeError(e) from None
+
+        return ckpt
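
The new `_load_model_checkpoint` helper follows a "strict load first, permissive load on `UnpicklingError`" pattern. The sketch below illustrates that control flow with the standard library only, so it runs without torch or OTX. The function name `load_checkpoint_with_fallback` and the two injected loader callables are hypothetical stand-ins for the `torch.load(...)` calls in the diff. One deliberate difference: in the diff, the final `except` branch builds `msg` but raises `RuntimeError(e) from None`, so the message is never used; this sketch passes the message and chains the original exception instead, which is one possible reading of the intent rather than the committed behavior.

```python
from __future__ import annotations

import pickle
from pathlib import Path
from typing import Any, Callable

# Hypothetical type alias: a loader takes a path and returns the checkpoint dict.
Loader = Callable[[Path], dict[str, Any]]


def load_checkpoint_with_fallback(
    checkpoint: str | Path,
    strict_loader: Loader,
    permissive_loader: Loader,
) -> dict[str, Any]:
    """Try the strict loader first; fall back only on UnpicklingError."""
    path = Path(checkpoint)
    if not path.exists():
        msg = f"Checkpoint file does not exist: {path}"
        raise FileNotFoundError(msg)

    try:
        return strict_loader(path)
    except pickle.UnpicklingError:
        # The strict loader rejected the file contents; retry permissively.
        # UnpicklingError must be caught before the generic Exception branch.
        return permissive_loader(path)
    except Exception as e:
        # Unlike the diff, the message is used and the cause is preserved.
        msg = f"Failed to load checkpoint from {path}. Please check the file."
        raise RuntimeError(msg) from e
```

In the committed code, the strict call would correspond to `torch.load(checkpoint, map_location=map_location)` and the permissive one to the `weights_only=False` call made inside `mock_modules_for_chkpt()`.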

src/otx/backend/native/models/anomaly/base.py

Lines changed: 29 additions & 2 deletions
@@ -5,6 +5,7 @@
 
 from __future__ import annotations
 
+from dataclasses import asdict
 from typing import TYPE_CHECKING, Any, Sequence
 
 import torch
@@ -283,6 +284,10 @@ def get_dummy_input(self, batch_size: int = 1) -> OTXDataBatch:
             masks=torch.tensor(0),
         )
 
+    @staticmethod
+    def _dispatch_label_info(*args, **kwargs) -> AnomalyLabelInfo:  # noqa: ARG004
+        return AnomalyLabelInfo()
+
 
 class AnomalyMixin:
     """Mixin inherited before AnomalibModule to override OTXModel methods."""
@@ -376,9 +381,9 @@ def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
         # calls Anomalib's on_save_checkpoint
         super().on_save_checkpoint(checkpoint)  # type: ignore[misc]
 
-        checkpoint["label_info"] = self.label_info  # type: ignore[attr-defined]
+        checkpoint["hyper_parameters"]["label_info"] = asdict(self.label_info)  # type: ignore[attr-defined]
         checkpoint["otx_version"] = __version__
-        checkpoint["tile_config"] = self.tile_config  # type: ignore[attr-defined]
+        checkpoint["hyper_parameters"]["tile_config"] = asdict(self.tile_config)  # type: ignore[attr-defined]
 
         attrs = ["_input_shape", "image_threshold", "pixel_threshold"]
         checkpoint["anomaly"] = {key: getattr(self, key, None) for key in attrs}
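
This hunk moves `label_info` and `tile_config` under `checkpoint["hyper_parameters"]` and stores them through `dataclasses.asdict` instead of as live objects. A likely motivation (an assumption, not stated in the commit) is that plain dicts, lists and strings can be read back by a restricted, weights-only `torch.load`, while custom classes cannot unless they are explicitly allow-listed. The sketch below shows the round trip with a hypothetical `DemoLabelInfo` dataclass; it is not the real OTX `LabelInfo`, and `setdefault` is added here only so the example works on an empty dict.

```python
from __future__ import annotations

from dataclasses import asdict, dataclass, field
from typing import Any


@dataclass
class DemoLabelInfo:
    """Hypothetical stand-in for a label-info dataclass."""

    label_names: list[str] = field(default_factory=list)
    label_groups: list[list[str]] = field(default_factory=list)


def save_label_info(checkpoint: dict[str, Any], label_info: DemoLabelInfo) -> None:
    """Store label info as plain containers under 'hyper_parameters'."""
    # setdefault is an addition for this sketch; the diff indexes the key directly.
    checkpoint.setdefault("hyper_parameters", {})["label_info"] = asdict(label_info)


def restore_label_info(checkpoint: dict[str, Any]) -> DemoLabelInfo | None:
    """Rebuild the dataclass from the stored dict, if present."""
    raw = checkpoint.get("hyper_parameters", {}).get("label_info")
    return DemoLabelInfo(**raw) if raw is not None else None
```

After `save_label_info`, the checkpoint holds only built-in containers, and `restore_label_info` plays the role that `_dispatch_label_info` appears to play on the loading side.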
@@ -390,3 +395,25 @@ def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
         if anomaly_attrs := checkpoint.get("anomaly"):
             for key, value in anomaly_attrs.items():
                 setattr(self, key, value)
+
+    def load_state_dict_incrementally(self, ckpt: dict[str, Any], *args, **kwargs) -> None:
+        """Bypass OTXModel's load_state_dict_incrementally."""
+        return self.load_state_dict(ckpt, *args, **kwargs)
+
+    def load_state_dict(self, ckpt: dict[str, Any], *args, **kwargs) -> None:
+        """Load state dictionary from checkpoint state dictionary.
+
+        If checkpoint's label_info and OTXLitModule's label_info are different,
+        load_state_pre_hook for smart weight loading will be registered.
+        """
+        self.on_load_checkpoint(ckpt)  # type: ignore[misc]
+        state_dict = ckpt["state_dict"]
+        if "label_info" in ckpt.get("hyper_parameters", {}):
+            label_info = self._dispatch_label_info(ckpt["hyper_parameters"]["label_info"])  # type: ignore[attr-defined]
+            if label_info != self.label_info:  # type: ignore[attr-defined]
+                msg = f"Checkpoint's label_info {label_info} "
+                "does not match with OTXLitModule's label_info {self.label_info}"  # type: ignore[attr-defined]
+                raise ValueError(
+                    msg,
+                )
+        return super().load_state_dict(state_dict, *args, **kwargs)  # type: ignore[misc]
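
A note on the error message in the added `load_state_dict`: as the diff is rendered here, the second string literal sits on its own line after the assignment, with no enclosing parentheses. Python would then treat it as a separate expression statement, so `msg` would contain only the first half, and the second literal has no `f` prefix, so `{self.label_info}` would not be interpolated in any case. The sketch below shows what seems to be the intended message; the helper name `label_info_mismatch_message` is hypothetical.

```python
def label_info_mismatch_message(ckpt_label_info: object, model_label_info: object) -> str:
    """Build the full mismatch message as a single expression."""
    # Parentheses join the adjacent literals into one string,
    # and each literal carries its own f prefix.
    return (
        f"Checkpoint's label_info {ckpt_label_info} "
        f"does not match with OTXLitModule's label_info {model_label_info}"
    )
```

A caller would then raise `ValueError(label_info_mismatch_message(label_info, self.label_info))` when the two values differ.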

src/otx/backend/native/models/anomaly/uflow.py

Lines changed: 1 addition & 0 deletions
@@ -63,3 +63,4 @@ def __init__(
             affine_subnet_channels_ratio=affine_subnet_channels_ratio,
             permute_soft=permute_soft,
         )
+        self._setup()
