Skip to content

Commit cb370f3

Browse files
authored
Tests/dssm second fit refits (#166)
- Added "second fit refits" test for DSSM
1 parent db27460 commit cb370f3

File tree

2 files changed

+25
-2
lines changed

2 files changed

+25
-2
lines changed

tests/models/test_dssm.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from rectools.models import DSSMModel
2626
from rectools.models.dssm import DSSM
2727
from rectools.models.vector import ImplicitRanker
28+
from tests.models.utils import assert_second_fit_refits_model
2829

2930
from .data import INTERACTIONS
3031

@@ -33,6 +34,9 @@
3334
@pytest.mark.filterwarnings("ignore::UserWarning")
3435
class TestDSSMModel:
3536
def setup_method(self) -> None:
37+
self._seed_everything()
38+
39+
def _seed_everything(self) -> None:
3640
seed_everything(42, workers=True)
3741

3842
@pytest.fixture
@@ -330,3 +334,7 @@ def test_raises_when_no_features_in_dataset(self, dataset: Dataset, exclude_feat
330334
model = DSSMModel()
331335
with pytest.raises(ValueError, match="requires user and item features"):
332336
model.fit(dataset)
337+
338+
def test_second_fit_refits_model(self, dataset: Dataset) -> None:
339+
model = DSSMModel(deterministic=True)
340+
assert_second_fit_refits_model(model, dataset, pre_fit_callback=self._seed_everything)

tests/models/utils.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import typing as tp
1516
from copy import deepcopy
1617

1718
import pandas as pd
@@ -20,9 +21,23 @@
2021
from rectools.models.base import ModelBase
2122

2223

23-
def assert_second_fit_refits_model(model: ModelBase, dataset: Dataset) -> None:
24+
def _dummy_func() -> None:
25+
pass
26+
27+
28+
def assert_second_fit_refits_model(
29+
model: ModelBase, dataset: Dataset, pre_fit_callback: tp.Optional[tp.Callable[[], None]] = None
30+
) -> None:
31+
pre_fit_callback = pre_fit_callback or _dummy_func
32+
33+
pre_fit_callback()
2434
model_1 = deepcopy(model).fit(dataset)
25-
model_2 = deepcopy(model).fit(dataset).fit(dataset)
35+
36+
pre_fit_callback()
37+
model_2 = deepcopy(model).fit(dataset)
38+
pre_fit_callback()
39+
model_2.fit(dataset)
40+
2641
k = dataset.item_id_map.external_ids.size
2742

2843
reco_u2i_1 = model_1.recommend(dataset.user_id_map.external_ids, dataset, k, False)

0 commit comments

Comments
 (0)