
Commit a6ec1c1

Feature/checkpoints (#281)
Added `map_location` and `model_params_update` arguments for the `load_from_checkpoint` method.
1 parent 6d03a4c commit a6ec1c1
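
As a quick orientation, here is a hedged usage sketch of the extended API; the model class, checkpoint path, and the dropped config keys below are placeholders for illustration, not part of this commit:

```python
from rectools.models import SASRecModel

# Hypothetical checkpoint produced by an earlier training run.
CKPT_PATH = "checkpoints/last_epoch.ckpt"

# Load the checkpoint onto CPU even if it was saved on a GPU machine,
# and drop training-only callables from the stored model config.
model = SASRecModel.load_from_checkpoint(
    CKPT_PATH,
    map_location="cpu",  # a torch.device also works; None keeps the original device
    model_params_update={
        "get_val_mask_func": None,  # keys are dot-flattened paths into the model config
        "get_trainer_func": None,
    },
)
```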

4 files changed: +54 −7 lines

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Added
 - Python 3.13 support ([#227](https://github.com/MobileTeleSystems/RecTools/pull/227))
 - `fit_partial` implementation for transformer-based models ([#273](https://github.com/MobileTeleSystems/RecTools/pull/273))
+- `map_location` and `model_params_update` arguments for the `load_from_checkpoint` function of transformer-based models. Use `map_location` to explicitly specify the new computing device and `model_params_update` to update the original model parameters (e.g. to remove training-specific parameters that are no longer needed) ([#281](https://github.com/MobileTeleSystems/RecTools/pull/281))
 
 ## [0.13.0] - 10.04.2025
 

rectools/models/nn/transformers/base.py

Lines changed: 22 additions & 6 deletions
@@ -29,7 +29,7 @@
 from rectools.dataset.dataset import Dataset, DatasetSchema, DatasetSchemaDict, IdMap
 from rectools.models.base import ErrorBehaviour, InternalRecoTriplet, ModelBase, ModelConfig
 from rectools.types import InternalIdsArray
-from rectools.utils.misc import get_class_or_function_full_path, import_object
+from rectools.utils.misc import get_class_or_function_full_path, import_object, make_dict_flat, unflatten_dict
 
 from ..item_net import (
     CatFeaturesItemNet,
@@ -623,20 +623,36 @@ def __setstate__(self, state: tp.Dict[str, tp.Any]) -> None:
         self.__dict__.update(loaded.__dict__)
 
     @classmethod
-    def load_from_checkpoint(cls, checkpoint_path: tp.Union[str, Path]) -> tpe.Self:
-        """
-        Load model from Lightning checkpoint path.
+    def load_from_checkpoint(
+        cls,
+        checkpoint_path: tp.Union[str, Path],
+        map_location: tp.Optional[tp.Union[str, torch.device]] = None,
+        model_params_update: tp.Optional[tp.Dict[str, tp.Any]] = None,
+    ) -> tpe.Self:
+        """Load model from Lightning checkpoint path.
 
         Parameters
         ----------
         checkpoint_path: Union[str, Path]
             Path to checkpoint location.
-
+        map_location: Union[str, torch.device], optional
+            Target device to load the checkpoint onto (e.g. 'cpu', 'cuda:0').
+            If None, the device the checkpoint was saved on is used.
+        model_params_update: Dict[str, tp.Any], optional
+            Contains custom values for checkpoint['hyper_parameters']['model_config'].
+            Has to be flattened with a 'dot' reducer before being passed.
+            You can use this argument to remove training-specific parameters that are not needed anymore,
+            e.g. 'get_trainer_func'.
         Returns
         -------
         Model instance.
         """
-        checkpoint = torch.load(checkpoint_path, weights_only=False)
+        checkpoint = torch.load(checkpoint_path, map_location=map_location, weights_only=False)
+        if model_params_update:
+            prev_model_config = checkpoint["hyper_parameters"]["model_config"]
+            prev_config_flatten = make_dict_flat(prev_model_config)
+            prev_config_flatten.update(model_params_update)
+            checkpoint["hyper_parameters"]["model_config"] = unflatten_dict(prev_config_flatten)
         loaded = cls._model_from_checkpoint(checkpoint)
         return loaded
 
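
The `model_params_update` merge above goes through `make_dict_flat` and `unflatten_dict` from `rectools.utils.misc`. A rough standalone illustration of the 'dot reducer' semantics follows; the helper functions and the nested `trainer_kwargs` key are simplified stand-ins for this sketch, not RecTools' actual implementation:

```python
import typing as tp


def flatten(d: tp.Dict[str, tp.Any], prefix: str = "") -> tp.Dict[str, tp.Any]:
    """Flatten a nested dict into dot-separated keys (illustration only)."""
    flat: tp.Dict[str, tp.Any] = {}
    for key, value in d.items():
        full_key = f"{prefix}{key}"
        if isinstance(value, dict):
            flat.update(flatten(value, prefix=f"{full_key}."))
        else:
            flat[full_key] = value
    return flat


def unflatten(flat: tp.Dict[str, tp.Any]) -> tp.Dict[str, tp.Any]:
    """Rebuild a nested dict from dot-separated keys (illustration only)."""
    nested: tp.Dict[str, tp.Any] = {}
    for key, value in flat.items():
        parts = key.split(".")
        node = nested
        for part in parts[:-1]:
            node = node.setdefault(part, {})
        node[parts[-1]] = value
    return nested


# Made-up stored model config: `trainer_kwargs` is hypothetical, used only to show nesting.
model_config = {
    "n_factors": 64,
    "get_trainer_func": "my_pkg.trainers.custom_trainer",
    "trainer_kwargs": {"max_epochs": 5},
}
# Updates address nested values via dot-separated keys, as `model_params_update` expects.
update = {"get_trainer_func": None, "trainer_kwargs.max_epochs": 1}

flat_config = flatten(model_config)
flat_config.update(update)
print(unflatten(flat_config))
# {'n_factors': 64, 'get_trainer_func': None, 'trainer_kwargs': {'max_epochs': 1}}
```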

tests/models/nn/transformers/test_base.py

Lines changed: 30 additions & 1 deletion
@@ -154,10 +154,37 @@ def test_save_load_for_fitted_model(
 
     @pytest.mark.parametrize("test_dataset", ("dataset", "dataset_item_features"))
     @pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel))
+    @pytest.mark.parametrize(
+        "map_location",
+        (
+            "cpu",
+            pytest.param(
+                "cuda:0",
+                marks=pytest.mark.skipif(torch.cuda.is_available() is False, reason="GPU is not available"),
+            ),
+            None,
+        ),
+    )
+    @pytest.mark.parametrize(
+        "model_params_update",
+        (
+            {
+                "get_val_mask_func": "tests.models.nn.transformers.utils.leave_one_out_mask",
+                "get_trainer_func": "tests.models.nn.transformers.utils.custom_trainer",
+            },
+            {
+                "get_val_mask_func": None,
+                "get_trainer_func": None,
+            },
+            None,
+        ),
+    )
     def test_load_from_checkpoint(
         self,
         model_cls: tp.Type[TransformerModelBase],
         test_dataset: str,
+        map_location: tp.Optional[tp.Union[str, torch.device]],
+        model_params_update: tp.Optional[tp.Dict[str, tp.Any]],
         request: FixtureRequest,
     ) -> None:
 
@@ -175,7 +202,9 @@ def test_load_from_checkpoint(
             raise ValueError("No log dir")
         ckpt_path = os.path.join(model.fit_trainer.log_dir, "checkpoints", "last_epoch.ckpt")
         assert os.path.isfile(ckpt_path)
-        recovered_model = model_cls.load_from_checkpoint(ckpt_path)
+        recovered_model = model_cls.load_from_checkpoint(
+            ckpt_path, map_location=map_location, model_params_update=model_params_update
+        )
         assert isinstance(recovered_model, model_cls)
 
         self._assert_same_reco(model, recovered_model, dataset)
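
Mirroring the test above outside of pytest, a hedged sketch of recovering a model from whatever checkpoint a previous `fit` left behind; the log directory layout and file pattern follow the test's assumptions and may differ in your setup:

```python
import glob
import os

from rectools.models import BERT4RecModel

# Hypothetical Lightning log directory written during a previous `fit`
# (the test reads it from `model.fit_trainer.log_dir`).
LOG_DIR = "lightning_logs/version_0"

ckpt_files = sorted(glob.glob(os.path.join(LOG_DIR, "checkpoints", "*.ckpt")))
if not ckpt_files:
    raise FileNotFoundError(f"No checkpoints found under {LOG_DIR}")

# Reload on CPU and drop training-only callables, as the parametrized test does.
recovered_model = BERT4RecModel.load_from_checkpoint(
    ckpt_files[-1],
    map_location="cpu",
    model_params_update={"get_val_mask_func": None, "get_trainer_func": None},
)
```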

tests/models/nn/transformers/test_bert4rec.py

Lines changed: 1 addition & 0 deletions
@@ -451,6 +451,7 @@ def test_i2i(
         whitelist: tp.Optional[np.ndarray],
         expected: pd.DataFrame,
     ) -> None:
+
         model = BERT4RecModel(
             n_factors=32,
             n_blocks=2,