Skip to content

Commit 061082c

Browse files
authored
feature/loo_mask (#292)
Added leave_one_out_mask function for transformer models training
1 parent af43135 commit 061082c

File tree

7 files changed

+138
-18
lines changed

7 files changed

+138
-18
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,11 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

8+
## Unreleased
9+
10+
### Added
11+
- `leave_one_out_mask` function (`rectools.models.nn.transformers.utils.leave_one_out_mask`) for applying leave-one-out validation when training transformer models ([#292](https://github.com/MobileTeleSystems/RecTools/pull/292))
12+
813
## [0.15.0] - 17.07.2025
914

1015
### Added
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import typing as tp
2+
3+
import numpy as np
4+
import pandas as pd
5+
6+
from rectools import Columns, ExternalIds
7+
8+
9+
def leave_one_out_mask(interactions: pd.DataFrame, val_users: tp.Union[ExternalIds, int, None] = None) -> np.ndarray:
    """
    Create a boolean mask for leave-one-out validation by selecting the last interaction per user.

    The last interaction of a user is the row with the greatest ``Columns.Datetime`` value among
    that user's rows. When several rows of one user share the greatest timestamp, the row that
    appears last in `interactions` is selected (ranks are assigned with ``method="first"``).
    The set of users can be restricted with the `val_users` parameter.

    Parameters
    ----------
    interactions : pd.DataFrame
        User-item interactions data with at least three columns:
        Columns.User, Columns.Item and Columns.Datetime.
        NOTE: datetimes are assumed to be non-missing; a missing value yields a NaN rank,
        which cannot be cast to int.
    val_users : Optional[Union[ExternalIds, int]], default ``None``
        Validation user filter. Can be:
        - None: use all users
        - int (Python or numpy integer): randomly sample N users from the unique user list
          without replacement, using the global numpy random state (``np.random.seed``)
        - array-like: explicit list of user IDs to include

    Returns
    -------
    np.ndarray
        Boolean array with one element per row of `interactions`, in the same row order.
        True indicates the interaction is the last one for its user and the user belongs
        to the validation set.

    Raises
    ------
    ValueError
        Propagated from ``np.random.choice`` when an integer `val_users` is negative or
        larger than the number of unique users.
    """
    groups = interactions.groupby(Columns.User)
    # method="first" breaks timestamp ties by order of appearance, so ranks within a user
    # are the distinct integers 1..n and the last tied row receives the highest rank.
    time_order = groups[Columns.Datetime].rank(method="first", ascending=True).astype(int)
    n_interactions = groups.transform("size").astype(int)
    # The row whose ascending rank equals the group size is the user's last interaction.
    last_interact_mask = time_order == n_interactions

    if val_users is None:
        return last_interact_mask.values

    # Accept numpy integer scalars as well as Python ints: counts derived from numpy/pandas
    # are frequently np.int64 and would otherwise be passed to `isin` as a non-list-like.
    if isinstance(val_users, (int, np.integer)):
        users = interactions[Columns.User].unique()
        val_users = np.random.choice(users, size=val_users, replace=False)

    mask = interactions[Columns.User].isin(val_users) & last_interact_mask
    return mask.values

tests/models/nn/transformers/test_base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,11 @@
3131
from rectools.models import BERT4RecModel, SASRecModel, load_model
3232
from rectools.models.nn.transformers.base import TransformerModelBase
3333
from rectools.models.nn.transformers.lightning import TransformerLightningModule
34+
from rectools.models.nn.transformers.utils import leave_one_out_mask
3435
from tests.models.data import INTERACTIONS
3536
from tests.models.utils import assert_save_load_do_not_change_model
3637

37-
from .utils import custom_trainer, custom_trainer_ckpt, custom_trainer_multiple_ckpt, leave_one_out_mask
38+
from .utils import custom_trainer, custom_trainer_ckpt, custom_trainer_multiple_ckpt
3839

3940

4041
def assert_torch_models_equal(model_a: nn.Module, model_b: nn.Module) -> None:
@@ -200,7 +201,7 @@ def test_save_load_for_fitted_model(
200201
"model_params_update",
201202
(
202203
{
203-
"get_val_mask_func": "tests.models.nn.transformers.utils.leave_one_out_mask",
204+
"get_val_mask_func": "rectools.models.nn.transformers.utils.leave_one_out_mask",
204205
"get_trainer_func": "tests.models.nn.transformers.utils.custom_trainer",
205206
},
206207
{

tests/models/nn/transformers/test_bert4rec.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,14 @@
3838
from rectools.models.nn.transformers.negative_sampler import CatalogUniformSampler, TransformerNegativeSamplerBase
3939
from rectools.models.nn.transformers.similarity import DistanceSimilarityModule
4040
from rectools.models.nn.transformers.torch_backbone import TransformerTorchBackbone
41+
from rectools.models.nn.transformers.utils import leave_one_out_mask
4142
from tests.models.data import DATASET
4243
from tests.models.utils import (
4344
assert_default_config_and_default_model_params_are_the_same,
4445
assert_second_fit_refits_model,
4546
)
4647

47-
from .utils import custom_trainer, leave_one_out_mask
48+
from .utils import custom_trainer
4849

4950

5051
class TestBERT4RecModel:
@@ -1027,7 +1028,7 @@ def test_get_config(
10271028
"data_preparator_type": "rectools.models.nn.transformers.bert4rec.BERT4RecDataPreparator",
10281029
"lightning_module_type": "rectools.models.nn.transformers.lightning.TransformerLightningModule",
10291030
"negative_sampler_type": "rectools.models.nn.transformers.negative_sampler.CatalogUniformSampler",
1030-
"get_val_mask_func": "tests.models.nn.transformers.utils.leave_one_out_mask",
1031+
"get_val_mask_func": "rectools.models.nn.transformers.utils.leave_one_out_mask",
10311032
"similarity_module_type": "rectools.models.nn.transformers.similarity.DistanceSimilarityModule",
10321033
"backbone_type": "rectools.models.nn.transformers.torch_backbone.TransformerTorchBackbone",
10331034
}

tests/models/nn/transformers/test_sasrec.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,15 @@
3737
from rectools.models.nn.transformers.sasrec import SASRecDataPreparator, SASRecTransformerLayers
3838
from rectools.models.nn.transformers.similarity import DistanceSimilarityModule
3939
from rectools.models.nn.transformers.torch_backbone import TransformerTorchBackbone
40+
from rectools.models.nn.transformers.utils import leave_one_out_mask
4041
from tests.models.data import DATASET
4142
from tests.models.utils import (
4243
assert_default_config_and_default_model_params_are_the_same,
4344
assert_second_fit_refits_model,
4445
)
4546
from tests.testing_utils import assert_id_map_equal, assert_interactions_set_equal
4647

47-
from .utils import custom_trainer, leave_one_out_mask
48+
from .utils import custom_trainer
4849

4950

5051
class TestSASRecModel:
@@ -1017,7 +1018,7 @@ def test_get_config(
10171018
"data_preparator_type": "rectools.models.nn.transformers.sasrec.SASRecDataPreparator",
10181019
"lightning_module_type": "rectools.models.nn.transformers.lightning.TransformerLightningModule",
10191020
"negative_sampler_type": "rectools.models.nn.transformers.negative_sampler.CatalogUniformSampler",
1020-
"get_val_mask_func": "tests.models.nn.transformers.utils.leave_one_out_mask",
1021+
"get_val_mask_func": "rectools.models.nn.transformers.utils.leave_one_out_mask",
10211022
"similarity_module_type": "rectools.models.nn.transformers.similarity.DistanceSimilarityModule",
10221023
"backbone_type": "rectools.models.nn.transformers.torch_backbone.TransformerTorchBackbone",
10231024
}
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Copyright 2022-2025 MTS (Mobile Telesystems)
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import typing as tp
16+
17+
import numpy as np
18+
import pandas as pd
19+
import pytest
20+
21+
from rectools import Columns
22+
from rectools.dataset import Interactions
23+
from rectools.models.nn.transformers.utils import leave_one_out_mask
24+
25+
26+
class TestLeaveOneOutMask:
    """Checks that `leave_one_out_mask` marks the latest interaction of every validation user."""

    def setup_method(self) -> None:
        # Cases with an integer `val_users` sample users through numpy's global random state,
        # so the seed is pinned to keep the expected indices deterministic.
        np.random.seed(32)

    @pytest.fixture
    def interactions(self) -> Interactions:
        """Small interactions table; the trailing comment on each row is its positional index."""
        df = pd.DataFrame(
            [
                [1, 1, 1, "2021-09-01"],  # 0
                [1, 2, 1, "2021-09-02"],  # 1
                [1, 1, 1, "2021-09-03"],  # 2
                [1, 2, 1, "2021-09-04"],  # 3
                [1, 3, 1, "2021-09-05"],  # 4
                [2, 3, 1, "2021-09-06"],  # 5
                [2, 2, 1, "2021-08-20"],  # 6
                [2, 2, 1, "2021-09-06"],  # 7
                [3, 1, 1, "2021-09-05"],  # 8
                [1, 6, 1, "2021-09-05"],  # 9
            ],
            columns=[Columns.User, Columns.Item, Columns.Weight, Columns.Datetime],
        ).astype({Columns.Datetime: "datetime64[ns]"})
        return Interactions(df)

    @pytest.mark.parametrize(
        "swap_interactions,expected_val_index, expected_val_item, val_users",
        (
            ([9, 9], [7, 8, 9], 6, None),
            ([9, 9], [7, 8, 9], 6, 3),
            ([9, 9], [8, 9], 6, 2),
            ([4, 9], [7, 8, 9], 3, None),
            ([4, 9], [7, 8, 9], 3, 3),
            ([4, 9], [8, 9], 3, 2),
            ([7, 7], [7, 8], 2, [2, 3]),
            ([5, 7], [7, 8], 3, [2, 3]),
            ([8, 8], [8], 1, [3]),
        ),
    )
    def test_correct_last_interactions(
        self,
        interactions: Interactions,
        swap_interactions: tp.List[int],
        expected_val_index: tp.List[int],
        expected_val_item: int,
        val_users: tp.Union[int, tp.List[int], None],
    ) -> None:
        """
        Swap two rows by position, then verify which rows are marked as validation.

        `swap_interactions` holds two positional indices whose rows are exchanged (a pair of
        equal indices leaves the table unchanged). The swapped pairs share a user and a
        timestamp, so the swap only changes which tied row comes last in the table.
        `expected_val_item` is the item expected at the larger of the two indices after the swap.
        """
        interactions_df = interactions.df
        swap_revert = swap_interactions[::-1]
        interactions_df.iloc[swap_interactions] = interactions_df.iloc[swap_revert]
        val_mask = leave_one_out_mask(interactions_df, val_users)
        val_interactions = interactions_df[val_mask]
        last_index = max(swap_interactions)

        assert list(val_interactions.index) == expected_val_index
        assert val_interactions.loc[last_index, Columns.Item] == expected_val_item

tests/models/nn/transformers/utils.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,9 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import pandas as pd
1615
from pytorch_lightning import Trainer
1716
from pytorch_lightning.callbacks import ModelCheckpoint
1817

19-
from rectools import Columns
20-
21-
22-
def leave_one_out_mask(interactions: pd.DataFrame) -> pd.Series:
    """
    Mark the most recent interaction of every user.

    Rows are ordered from newest to oldest with a stable sort, numbered within each user,
    and the row numbered 0 is flagged. The returned boolean Series keeps the index labels of
    `interactions` but in the datetime-sorted row order, not the original one.
    """
    newest_first = interactions.sort_values(Columns.Datetime, ascending=False, kind="stable")
    # cumcount numbers the rows of each user 0, 1, 2, ... in newest-first order.
    position_within_user = newest_first.groupby(Columns.User, sort=False).cumcount()
    return position_within_user == 0
29-
3018

3119
def custom_trainer() -> Trainer:
3220
return Trainer(

0 commit comments

Comments
 (0)