@@ -29,7 +29,7 @@
 from rectools.dataset.dataset import Dataset, DatasetSchema, DatasetSchemaDict, IdMap
 from rectools.models.base import ErrorBehaviour, InternalRecoTriplet, ModelBase, ModelConfig
 from rectools.types import InternalIdsArray
-from rectools.utils.misc import get_class_or_function_full_path, import_object
+from rectools.utils.misc import get_class_or_function_full_path, import_object, make_dict_flat, unflatten_dict

 from ..item_net import (
     CatFeaturesItemNet,
@@ -623,20 +623,37 @@ def __setstate__(self, state: tp.Dict[str, tp.Any]) -> None:
         self.__dict__.update(loaded.__dict__)

     @classmethod
-    def load_from_checkpoint(cls, checkpoint_path: tp.Union[str, Path]) -> tpe.Self:
-        """
-        Load model from Lightning checkpoint path.
+    def load_from_checkpoint(
+        cls,
+        checkpoint_path: tp.Union[str, Path],
+        map_location: tp.Optional[tp.Union[str, torch.device]] = None,
+        model_params_update: tp.Optional[tp.Dict[str, tp.Any]] = None,
+    ) -> tpe.Self:
+        """Load model from Lightning checkpoint path.

         Parameters
         ----------
         checkpoint_path: Union[str, Path]
             Path to checkpoint location.
-
+        map_location: Union[str, torch.device], optional
+            Target device to load the checkpoint onto (e.g. 'cpu', 'cuda:0').
+            If None, the device the checkpoint was saved on is used.
+        model_params_update: Dict[str, Any], optional
+            Custom values to override in checkpoint['hyper_parameters']['model_config'].
+            Keys must be flattened with the 'dot' reducer before being passed.
+            Use this argument to drop training-specific parameters that are no longer needed,
+            e.g. 'get_trainer_func'.
+
         Returns
         -------
         Model instance.
         """
-        checkpoint = torch.load(checkpoint_path, weights_only=False)
+        checkpoint = torch.load(checkpoint_path, map_location=map_location, weights_only=False)
+        if model_params_update:
+            prev_model_config = checkpoint["hyper_parameters"]["model_config"]
+            prev_config_flatten = make_dict_flat(prev_model_config)
+            prev_config_flatten.update(model_params_update)
+            checkpoint["hyper_parameters"]["model_config"] = unflatten_dict(prev_config_flatten)
         loaded = cls._model_from_checkpoint(checkpoint)
         return loaded
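
For reference, a minimal sketch of the 'dot' flattening round trip this change relies on. The nested config and its values are made up for illustration, and the lossless round trip is an assumption based on how make_dict_flat and unflatten_dict are used above:

from rectools.utils.misc import make_dict_flat, unflatten_dict

# Illustrative nested config (keys and values are not from this PR).
config = {"data_preparator_kwargs": {"shuffle_train": True}, "epochs": 3}

# Nested keys are joined with the 'dot' reducer:
# {'data_preparator_kwargs.shuffle_train': True, 'epochs': 3}
flat = make_dict_flat(config)
flat["epochs"] = 5  # override a value in flattened form, as model_params_update does

# Unflattening restores the nested structure with the override applied.
assert unflatten_dict(flat) == {"data_preparator_kwargs": {"shuffle_train": True}, "epochs": 5}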
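
A usage sketch for the updated classmethod follows. The concrete model class, checkpoint path, and override value are assumptions for illustration; the method is inherited by the transformer models built on this base:

from rectools.models import SASRecModel  # assumption: any model inheriting this base works the same way

model = SASRecModel.load_from_checkpoint(
    "checkpoints/last.ckpt",  # hypothetical checkpoint path
    map_location="cpu",  # remap tensors saved on GPU onto CPU
    model_params_update={
        # Dot-flattened key into model_config; None is one plausible way
        # to clear a training-only callable that is no longer needed.
        "get_trainer_func": None,
    },
)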