|
13 | 13 | import time |
14 | 14 | from contextlib import contextmanager |
15 | 15 | from pathlib import Path |
| 16 | +from pickle import UnpicklingError # nosec B403: UnpicklingError is used only for exception handling |
16 | 17 | from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Iterator, Literal |
17 | 18 | from warnings import warn |
18 | 19 |
|
@@ -267,7 +268,7 @@ def train( |
267 | 268 | # load the model state from the checkpoint incrementally. |
268 | 269 | # This means only the model weights are loaded. If there is a mismatch in label_info, |
269 | 270 | # perform incremental weight loading for the model's classification layer. |
270 | | - ckpt = torch.load(checkpoint) |
| 271 | + ckpt = self._load_model_checkpoint(checkpoint, map_location="cpu") |
271 | 272 | self.model.load_state_dict_incrementally(ckpt) |
272 | 273 |
|
273 | 274 | with override_metric_callable(model=self.model, new_metric_callable=metric) as model: |
@@ -342,10 +343,8 @@ def test( |
342 | 343 | # NOTE, trainer.test takes only lightning based checkpoint. |
343 | 344 | # So, it can't take the OTX1.x checkpoint. |
344 | 345 | if checkpoint is not None: |
345 | | - kwargs_user_input: dict[str, Any] = {} |
346 | | - |
347 | | - model_cls = model.__class__ |
348 | | - model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint, **kwargs_user_input) |
| 346 | + ckpt = self._load_model_checkpoint(checkpoint, map_location="cpu") |
| 347 | + model.load_state_dict(ckpt) |
349 | 348 |
|
350 | 349 | if model.label_info != self.datamodule.label_info: |
351 | 350 | if ( |
@@ -432,10 +431,8 @@ def predict( |
432 | 431 | datamodule = datamodule if datamodule is not None else self.datamodule |
433 | 432 |
|
434 | 433 | if checkpoint is not None: |
435 | | - kwargs_user_input: dict[str, Any] = {} |
436 | | - |
437 | | - model_cls = model.__class__ |
438 | | - model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint, **kwargs_user_input) |
| 434 | + ckpt = self._load_model_checkpoint(checkpoint, map_location="cpu") |
| 435 | + model.load_state_dict(ckpt) |
439 | 436 |
|
440 | 437 | if model.label_info != self.datamodule.label_info: |
441 | 438 | msg = ( |
@@ -534,20 +531,8 @@ def export( |
534 | 531 | warn(msg, stacklevel=1) |
535 | 532 | export_demo_package = False |
536 | 533 |
|
537 | | - kwargs_user_input: dict[str, Any] = {} |
538 | | - |
539 | | - model_cls = self.model.__class__ |
540 | | - if hasattr(self.model, "model_name"): |
541 | | - # NOTE: This is a solution to fix backward compatibility issue. |
542 | | - # If the model has `model_name` attribute, it will be passed to the `load_from_checkpoint` method, |
543 | | - # making sure previous model trained without model_name can be loaded. |
544 | | - kwargs_user_input["model_name"] = self.model.model_name |
545 | | - |
546 | | - self._model = model_cls.load_from_checkpoint( |
547 | | - checkpoint_path=checkpoint, |
548 | | - map_location="cpu", |
549 | | - **kwargs_user_input, |
550 | | - ) |
| 534 | + ckpt = self._load_model_checkpoint(checkpoint, map_location="cpu") |
| 535 | + self.model.load_state_dict(ckpt) |
551 | 536 | self.model.eval() |
552 | 537 |
|
553 | 538 | self.model.explain_mode = explain |
@@ -617,10 +602,8 @@ def explain( |
617 | 602 | datamodule = datamodule if datamodule is not None else self.datamodule |
618 | 603 |
|
619 | 604 | if checkpoint is not None: |
620 | | - kwargs_user_input: dict[str, Any] = {} |
621 | | - |
622 | | - model_cls = model.__class__ |
623 | | - model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint, **kwargs_user_input) |
| 605 | + ckpt = self._load_model_checkpoint(checkpoint, map_location="cpu") |
| 606 | + model.load_state_dict(ckpt) |
624 | 607 |
|
625 | 608 | if model.label_info != self.datamodule.label_info: |
626 | 609 | msg = ( |
@@ -706,14 +689,8 @@ def benchmark( |
706 | 689 | checkpoint = checkpoint if checkpoint is not None else self.checkpoint |
707 | 690 |
|
708 | 691 | if checkpoint is not None: |
709 | | - kwargs_user_input: dict[str, Any] = {} |
710 | | - |
711 | | - model_cls = self.model.__class__ |
712 | | - self._model = model_cls.load_from_checkpoint( |
713 | | - checkpoint_path=checkpoint, |
714 | | - map_location="cpu", |
715 | | - **kwargs_user_input, |
716 | | - ) |
| 692 | + ckpt = self._load_model_checkpoint(checkpoint, map_location="cpu") |
| 693 | + self.model.load_state_dict(ckpt) |
717 | 694 | self.model.eval() |
718 | 695 |
|
719 | 696 | def dummy_infer(model: OTXModel, batch_size: int = 1) -> float: |
@@ -1090,3 +1067,30 @@ def datamodule(self) -> OTXDataModule: |
1090 | 1067 | def is_supported(model: MODEL, data: DATA) -> bool: |
1091 | 1068 | """Check if the engine is supported for the given model and data.""" |
1092 | 1069 | return bool(isinstance(model, OTXModel) and isinstance(data, OTXDataModule)) |
| 1070 | + |
| 1071 | + @staticmethod |
| 1072 | + def _load_model_checkpoint(checkpoint: PathLike, map_location: str | None = None) -> dict[str, Any]: |
| 1073 | + """Load model checkpoint from the given path. |
| 1074 | +
|
| 1075 | + Args: |
| 1076 | + checkpoint (PathLike): Path to the checkpoint file. |
| 1077 | +
|
| 1078 | + Returns: |
| 1079 | + dict[str, Any]: The loaded state dictionary from the checkpoint. |
| 1080 | + """ |
| 1081 | + if not Path(checkpoint).exists(): |
| 1082 | + msg = f"Checkpoint file does not exist: {checkpoint}" |
| 1083 | + raise FileNotFoundError(msg) |
| 1084 | + |
| 1085 | + try: |
| 1086 | + ckpt = torch.load(checkpoint, map_location=map_location) |
| 1087 | + except UnpicklingError: |
| 1088 | + from otx.backend.native.utils.utils import mock_modules_for_chkpt |
| 1089 | + |
| 1090 | + with mock_modules_for_chkpt(): |
| 1091 | + ckpt = torch.load(checkpoint, map_location=map_location, weights_only=False) |
| 1092 | + except Exception as e: |
| 1093 | + msg = f"Failed to load checkpoint from {checkpoint}. Please check the file." |
| 1094 | + raise RuntimeError(e) from None |
| 1095 | + |
| 1096 | + return ckpt |
0 commit comments