|
14 | 14 |
|
15 | 15 | import typing as tp |
16 | 16 | from tempfile import NamedTemporaryFile |
| 17 | +from unittest.mock import MagicMock |
17 | 18 |
|
18 | 19 | import pytest |
19 | 20 | from implicit.als import AlternatingLeastSquares |
|
26 | 27 | except ImportError: |
27 | 28 | LightFM = object # it's ok in case we're skipping the tests |
28 | 29 |
|
29 | | - |
30 | 30 | from rectools.metrics import NDCG |
31 | 31 | from rectools.models import ( |
32 | 32 | DSSMModel, |
|
39 | 39 | PopularModel, |
40 | 40 | load_model, |
41 | 41 | model_from_config, |
| 42 | + model_from_params, |
| 43 | + serialization, |
42 | 44 | ) |
43 | 45 | from rectools.models.base import ModelBase, ModelConfig |
44 | 46 | from rectools.models.vector import VectorModel |
| 47 | +from rectools.utils.config import BaseConfig |
45 | 48 |
|
46 | 49 | from .utils import get_successors |
47 | 50 |
|
@@ -77,20 +80,26 @@ def test_load_model(model_cls: tp.Type[ModelBase]) -> None: |
77 | 80 | assert isinstance(loaded_model, model_cls) |
78 | 81 |
|
79 | 82 |
|
| 83 | +class CustomModelSubConfig(BaseConfig): |
| 84 | + x: int = 10 |
| 85 | + |
| 86 | + |
80 | 87 | class CustomModelConfig(ModelConfig): |
81 | 88 | some_param: int = 1 |
| 89 | + sc: CustomModelSubConfig = CustomModelSubConfig() |
82 | 90 |
|
83 | 91 |
|
84 | 92 | class CustomModel(ModelBase[CustomModelConfig]): |
85 | 93 | config_class = CustomModelConfig |
86 | 94 |
|
87 | | - def __init__(self, some_param: int = 1, verbose: int = 0): |
| 95 | + def __init__(self, some_param: int = 1, x: int = 10, verbose: int = 0): |
88 | 96 | super().__init__(verbose=verbose) |
89 | 97 | self.some_param = some_param |
| 98 | + self.x = x |
90 | 99 |
|
91 | 100 | @classmethod |
92 | 101 | def _from_config(cls, config: CustomModelConfig) -> "CustomModel": |
93 | | - return cls(some_param=config.some_param, verbose=config.verbose) |
| 102 | + return cls(some_param=config.some_param, x=config.sc.x, verbose=config.verbose) |
94 | 103 |
|
95 | 104 |
|
96 | 105 | class TestModelFromConfig: |
@@ -119,6 +128,7 @@ def test_custom_model_creation(self, config: tp.Union[dict, CustomModelConfig]) |
119 | 128 | model = model_from_config(config) |
120 | 129 | assert isinstance(model, CustomModel) |
121 | 130 | assert model.some_param == 2 |
| 131 | + assert model.x == 10 |
122 | 132 |
|
123 | 133 | @pytest.mark.parametrize("simple_types", (False, True)) |
124 | 134 | def test_fails_on_missing_cls(self, simple_types: bool) -> None: |
@@ -177,3 +187,15 @@ def test_fails_on_model_cls_without_from_config_support(self, model_cls: tp.Any) |
177 | 187 | config = {"cls": model_cls} |
178 | 188 | with pytest.raises(NotImplementedError, match="`from_config` method is not implemented for `DSSMModel` model"): |
179 | 189 | model_from_config(config) |
| 190 | + |
| 191 | + |
| 192 | +class TestModelFromParams: |
| 193 | + def test_uses_from_config(self, mocker: MagicMock) -> None: |
| 194 | + params = {"cls": "tests.models.test_serialization.CustomModel", "some_param": 2, "sc.x": 20} |
| 195 | + spy = mocker.spy(serialization, "model_from_config") |
| 196 | + model = model_from_params(params) |
| 197 | + expected_config = {"cls": "tests.models.test_serialization.CustomModel", "some_param": 2, "sc": {"x": 20}} |
| 198 | + spy.assert_called_once_with(expected_config) |
| 199 | + assert isinstance(model, CustomModel) |
| 200 | + assert model.some_param == 2 |
| 201 | + assert model.x == 20 |
0 commit comments