Skip to content

Commit d85f5da

Browse files
authored
Fix applying model's hparams when loading model from checkpoint (#4057)
1 parent 7b07e6b commit d85f5da

File tree

3 files changed

+38
-10
lines changed

3 files changed

+38
-10
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,8 @@ All notable changes to this project will be documented in this file.
100100
(<https://github.com/openvinotoolkit/training_extensions/pull/4056>)
101101
- Upgrade MAPI in 2.2
102102
(<https://github.com/openvinotoolkit/training_extensions/pull/4052>)
103+
- Fix applying model's hparams when loading model from checkpoint
104+
(<https://github.com/openvinotoolkit/training_extensions/pull/4057>)
103105

104106
## \[v2.1.0\]
105107

src/otx/engine/engine.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -367,8 +367,14 @@ def test(
367367
# NOTE, trainer.test takes only lightning based checkpoint.
368368
# So, it can't take the OTX1.x checkpoint.
369369
if checkpoint is not None and not is_ir_ckpt:
370+
kwargs_user_input: dict[str, Any] = {}
371+
if self.task == OTXTaskType.ZERO_SHOT_VISUAL_PROMPTING:
372+
# to update user's custom infer_reference_info_root through cli for zero-shot learning
373+
# TODO (sungchul): revisit for better solution
374+
kwargs_user_input.update(infer_reference_info_root=self.model.infer_reference_info_root)
375+
370376
model_cls = model.__class__
371-
model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint, **model.hparams)
377+
model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint, **kwargs_user_input)
372378

373379
if model.label_info != self.datamodule.label_info:
374380
if (
@@ -462,8 +468,14 @@ def predict(
462468
datamodule = self._auto_configurator.update_ov_subset_pipeline(datamodule=datamodule, subset="test")
463469

464470
if checkpoint is not None and not is_ir_ckpt:
471+
kwargs_user_input: dict[str, Any] = {}
472+
if self.task == OTXTaskType.ZERO_SHOT_VISUAL_PROMPTING:
473+
# to update user's custom infer_reference_info_root through cli for zero-shot learning
474+
# TODO (sungchul): revisit for better solution
475+
kwargs_user_input.update(infer_reference_info_root=self.model.infer_reference_info_root)
476+
465477
model_cls = model.__class__
466-
model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint, **model.hparams)
478+
model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint, **kwargs_user_input)
467479

468480
if model.label_info != self.datamodule.label_info:
469481
msg = (
@@ -574,11 +586,17 @@ def export(
574586
)
575587

576588
if not is_ir_ckpt:
589+
kwargs_user_input: dict[str, Any] = {}
590+
if self.task == OTXTaskType.ZERO_SHOT_VISUAL_PROMPTING:
591+
# to update user's custom infer_reference_info_root through cli for zero-shot learning
592+
# TODO (sungchul): revisit for better solution
593+
kwargs_user_input.update(infer_reference_info_root=self.model.infer_reference_info_root)
594+
577595
model_cls = self.model.__class__
578596
self.model = model_cls.load_from_checkpoint(
579597
checkpoint_path=checkpoint,
580598
map_location="cpu",
581-
**self.model.hparams,
599+
**kwargs_user_input,
582600
)
583601
self.model.eval()
584602

@@ -742,8 +760,14 @@ def explain(
742760
model = self._auto_configurator.get_ov_model(model_name=str(checkpoint), label_info=datamodule.label_info)
743761

744762
if checkpoint is not None and not is_ir_ckpt:
763+
kwargs_user_input: dict[str, Any] = {}
764+
if self.task == OTXTaskType.ZERO_SHOT_VISUAL_PROMPTING:
765+
# to update user's custom infer_reference_info_root through cli for zero-shot learning
766+
# TODO (sungchul): revisit for better solution
767+
kwargs_user_input.update(infer_reference_info_root=self.model.infer_reference_info_root)
768+
745769
model_cls = model.__class__
746-
model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint, **model.hparams)
770+
model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint, **kwargs_user_input)
747771

748772
if model.label_info != self.datamodule.label_info:
749773
msg = (
@@ -845,11 +869,17 @@ def benchmark(
845869
)
846870

847871
if not is_ir_ckpt:
872+
kwargs_user_input: dict[str, Any] = {}
873+
if self.task == OTXTaskType.ZERO_SHOT_VISUAL_PROMPTING:
874+
# to update user's custom infer_reference_info_root through cli for zero-shot learning
875+
# TODO (sungchul): revisit for better solution
876+
kwargs_user_input.update(infer_reference_info_root=self.model.infer_reference_info_root)
877+
848878
model_cls = self.model.__class__
849879
self.model = model_cls.load_from_checkpoint(
850880
checkpoint_path=checkpoint,
851881
map_location="cpu",
852-
**self.model.hparams,
882+
**kwargs_user_input,
853883
)
854884
elif isinstance(self.model, OVModel):
855885
msg = "To run benchmark on OV model, checkpoint must be specified."

tests/unit/engine/test_engine.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -223,11 +223,7 @@ def test_exporting(self, fxt_engine, mocker) -> None:
223223
checkpoint = "path/to/checkpoint.ckpt"
224224
fxt_engine.checkpoint = checkpoint
225225
fxt_engine.export()
226-
mock_load_from_checkpoint.assert_called_once_with(
227-
checkpoint_path=checkpoint,
228-
map_location="cpu",
229-
**fxt_engine.model.hparams,
230-
)
226+
mock_load_from_checkpoint.assert_called_once_with(checkpoint_path=checkpoint, map_location="cpu")
231227
mock_export.assert_called_once_with(
232228
output_dir=Path(fxt_engine.work_dir),
233229
base_name="exported_model",

0 commit comments

Comments
 (0)