Skip to content

Commit 5664a3c

Browse files
authored
Feature/model from params (#252)
- Added `from_params` method for models and `model_from_params` function
1 parent fa6c201 commit 5664a3c

File tree

8 files changed

+161
-7
lines changed

8 files changed

+161
-7
lines changed

CHANGELOG.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,13 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

8+
89
## Unreleased
9-
- Add `use_gpu` for PureSVD ([#229](https://github.com/MobileTeleSystems/RecTools/pull/229))
10+
11+
### Added
12+
- `use_gpu` for PureSVD ([#229](https://github.com/MobileTeleSystems/RecTools/pull/229))
13+
- `from_params` method for models and `model_from_params` function ([#252](https://github.com/MobileTeleSystems/RecTools/pull/252))
14+
1015

1116
## [0.10.0] - 16.01.2025
1217

rectools/models/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
`models.DSSMModel`
2929
`models.EASEModel`
3030
`models.ImplicitALSWrapperModel`
31+
`models.ImplicitBPRWrapperModel`
3132
`models.ImplicitItemKNNWrapperModel`
3233
`models.LightFMWrapperModel`
3334
`models.PopularModel`
@@ -44,7 +45,7 @@
4445
from .popular_in_category import PopularInCategoryModel
4546
from .pure_svd import PureSVDModel
4647
from .random import RandomModel
47-
from .serialization import load_model, model_from_config
48+
from .serialization import load_model, model_from_config, model_from_params
4849

4950
try:
5051
from .lightfm import LightFMWrapperModel
@@ -70,4 +71,5 @@
7071
"DSSMModel",
7172
"load_model",
7273
"model_from_config",
74+
"model_from_params",
7375
)

rectools/models/base.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from rectools.exceptions import NotFittedError
3232
from rectools.types import ExternalIdsArray, InternalIdsArray
3333
from rectools.utils.config import BaseConfig
34-
from rectools.utils.misc import get_class_or_function_full_path, import_object, make_dict_flat
34+
from rectools.utils.misc import get_class_or_function_full_path, import_object, make_dict_flat, unflatten_dict
3535
from rectools.utils.serialization import PICKLE_PROTOCOL, FileLike, read_bytes
3636

3737
T = tp.TypeVar("T", bound="ModelBase")
@@ -210,6 +210,26 @@ def from_config(cls, config: tp.Union[dict, ModelConfig_T]) -> tpe.Self:
210210

211211
return cls._from_config(config_obj)
212212

213+
@classmethod
214+
def from_params(cls, params: tp.Dict[str, tp.Any], sep: str = ".") -> tpe.Self:
215+
"""
216+
Create model from parameters.
217+
Same as `from_config` but accepts flat dict.
218+
219+
Parameters
220+
----------
221+
params : dict
222+
Model parameters as a flat dict with keys separated by `sep`.
223+
sep : str, default "."
224+
Separator for nested keys.
225+
226+
Returns
227+
-------
228+
Model instance.
229+
"""
230+
config_dict = unflatten_dict(params, sep=sep)
231+
return cls.from_config(config_dict)
232+
213233
@classmethod
214234
def _from_config(cls, config: ModelConfig_T) -> tpe.Self:
215235
raise NotImplementedError()

rectools/models/serialization.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from pydantic import TypeAdapter
1919

2020
from rectools.models.base import ModelBase, ModelClass, ModelConfig
21+
from rectools.utils.misc import unflatten_dict
2122
from rectools.utils.serialization import FileLike, read_bytes
2223

2324

@@ -46,7 +47,7 @@ def model_from_config(config: tp.Union[dict, ModelConfig]) -> ModelBase:
4647
4748
Parameters
4849
----------
49-
config : ModelConfig
50+
config : dict or ModelConfig
5051
Model config.
5152
5253
Returns
@@ -64,3 +65,24 @@ def model_from_config(config: tp.Union[dict, ModelConfig]) -> ModelBase:
6465
raise ValueError("`cls` must be provided in the config to load the model")
6566

6667
return model_cls.from_config(config)
68+
69+
70+
def model_from_params(params: dict, sep: str = ".") -> ModelBase:
71+
"""
72+
Create model from dict of parameters.
73+
Same as `from_config` but accepts flat dict.
74+
75+
Parameters
76+
----------
77+
params : dict
78+
Model parameters as a flat dict with keys separated by `sep`.
79+
sep : str, default "."
80+
Separator for nested keys.
81+
82+
Returns
83+
-------
84+
model
85+
Model instance.
86+
"""
87+
config_dict = unflatten_dict(params, sep=sep)
88+
return model_from_config(config_dict)

rectools/utils/misc.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,3 +228,34 @@ def make_dict_flat(d: tp.Dict[str, tp.Any], sep: str = ".", parent_key: str = ""
228228
else:
229229
items.append((new_key, v))
230230
return dict(items)
231+
232+
233+
def unflatten_dict(d: tp.Dict[str, tp.Any], sep: str = ".") -> tp.Dict[str, tp.Any]:
234+
"""
235+
Convert a flat dict with concatenated keys back into a nested dictionary.
236+
237+
Parameters
238+
----------
239+
d : dict
240+
Flattened dictionary.
241+
sep : str, default "."
242+
Separator used in flattened keys.
243+
244+
Returns
245+
-------
246+
dict
247+
Nested dictionary.
248+
249+
Examples
250+
--------
251+
>>> unflatten_dict({'a.b': 1, 'a.c': 2, 'd': 3})
252+
{'a': {'b': 1, 'c': 2}, 'd': 3}
253+
"""
254+
result: tp.Dict[str, tp.Any] = {}
255+
for key, value in d.items():
256+
parts = key.split(sep)
257+
current = result
258+
for part in parts[:-1]:
259+
current = current.setdefault(part, {})
260+
current[parts[-1]] = value
261+
return result

tests/models/test_base.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from datetime import timedelta
2020
from pathlib import Path
2121
from tempfile import NamedTemporaryFile, TemporaryFile
22+
from unittest.mock import MagicMock
2223

2324
import numpy as np
2425
import pandas as pd
@@ -498,6 +499,15 @@ def test_from_config_dict_with_extra_keys(self) -> None:
498499
):
499500
self.model_class.from_config(config)
500501

502+
def test_from_params(self, mocker: MagicMock) -> None:
503+
params = {"x": 10, "verbose": 1, "sc.td": "P2DT3H"}
504+
spy = mocker.spy(self.model_class, "from_config")
505+
model = self.model_class.from_params(params)
506+
spy.assert_called_once_with({"x": 10, "verbose": 1, "sc": {"td": "P2DT3H"}})
507+
assert model.x == 10
508+
assert model.td == timedelta(days=2, hours=3)
509+
assert model.verbose == 1
510+
501511
def test_get_config_pydantic(self) -> None:
502512
model = self.model_class(x=10, verbose=1)
503513
config = model.get_config(mode="pydantic")

tests/models/test_serialization.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import typing as tp
1616
from tempfile import NamedTemporaryFile
17+
from unittest.mock import MagicMock
1718

1819
import pytest
1920
from implicit.als import AlternatingLeastSquares
@@ -26,7 +27,6 @@
2627
except ImportError:
2728
LightFM = object # it's ok in case we're skipping the tests
2829

29-
3030
from rectools.metrics import NDCG
3131
from rectools.models import (
3232
DSSMModel,
@@ -39,9 +39,12 @@
3939
PopularModel,
4040
load_model,
4141
model_from_config,
42+
model_from_params,
43+
serialization,
4244
)
4345
from rectools.models.base import ModelBase, ModelConfig
4446
from rectools.models.vector import VectorModel
47+
from rectools.utils.config import BaseConfig
4548

4649
from .utils import get_successors
4750

@@ -77,20 +80,26 @@ def test_load_model(model_cls: tp.Type[ModelBase]) -> None:
7780
assert isinstance(loaded_model, model_cls)
7881

7982

83+
class CustomModelSubConfig(BaseConfig):
84+
x: int = 10
85+
86+
8087
class CustomModelConfig(ModelConfig):
8188
some_param: int = 1
89+
sc: CustomModelSubConfig = CustomModelSubConfig()
8290

8391

8492
class CustomModel(ModelBase[CustomModelConfig]):
8593
config_class = CustomModelConfig
8694

87-
def __init__(self, some_param: int = 1, verbose: int = 0):
95+
def __init__(self, some_param: int = 1, x: int = 10, verbose: int = 0):
8896
super().__init__(verbose=verbose)
8997
self.some_param = some_param
98+
self.x = x
9099

91100
@classmethod
92101
def _from_config(cls, config: CustomModelConfig) -> "CustomModel":
93-
return cls(some_param=config.some_param, verbose=config.verbose)
102+
return cls(some_param=config.some_param, x=config.sc.x, verbose=config.verbose)
94103

95104

96105
class TestModelFromConfig:
@@ -119,6 +128,7 @@ def test_custom_model_creation(self, config: tp.Union[dict, CustomModelConfig])
119128
model = model_from_config(config)
120129
assert isinstance(model, CustomModel)
121130
assert model.some_param == 2
131+
assert model.x == 10
122132

123133
@pytest.mark.parametrize("simple_types", (False, True))
124134
def test_fails_on_missing_cls(self, simple_types: bool) -> None:
@@ -177,3 +187,15 @@ def test_fails_on_model_cls_without_from_config_support(self, model_cls: tp.Any)
177187
config = {"cls": model_cls}
178188
with pytest.raises(NotImplementedError, match="`from_config` method is not implemented for `DSSMModel` model"):
179189
model_from_config(config)
190+
191+
192+
class TestModelFromParams:
193+
def test_uses_from_config(self, mocker: MagicMock) -> None:
194+
params = {"cls": "tests.models.test_serialization.CustomModel", "some_param": 2, "sc.x": 20}
195+
spy = mocker.spy(serialization, "model_from_config")
196+
model = model_from_params(params)
197+
expected_config = {"cls": "tests.models.test_serialization.CustomModel", "some_param": 2, "sc": {"x": 20}}
198+
spy.assert_called_once_with(expected_config)
199+
assert isinstance(model, CustomModel)
200+
assert model.some_param == 2
201+
assert model.x == 20

tests/utils/test_misc.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
from rectools.utils.misc import unflatten_dict
2+
3+
4+
class TestUnflattenDict:
5+
def test_empty(self) -> None:
6+
assert unflatten_dict({}) == {}
7+
8+
def test_complex(self) -> None:
9+
flattened = {
10+
"a.b": 1,
11+
"a.c": 2,
12+
"d": 3,
13+
"a.e.f": [10, 20],
14+
}
15+
excepted = {
16+
"a": {"b": 1, "c": 2, "e": {"f": [10, 20]}},
17+
"d": 3,
18+
}
19+
assert unflatten_dict(flattened) == excepted
20+
21+
def test_simple(self) -> None:
22+
flattened = {
23+
"a": 1,
24+
"b": 2,
25+
}
26+
excepted = {
27+
"a": 1,
28+
"b": 2,
29+
}
30+
assert unflatten_dict(flattened) == excepted
31+
32+
def test_non_default_sep(self) -> None:
33+
flattened = {
34+
"a_b": 1,
35+
"a_c": 2,
36+
"d": 3,
37+
}
38+
excepted = {
39+
"a": {"b": 1, "c": 2},
40+
"d": 3,
41+
}
42+
assert unflatten_dict(flattened, sep="_") == excepted

0 commit comments

Comments
 (0)