Skip to content

Commit 589c7ca

Browse files
authored
Feature/correct splitter (#288)
Fixed LasNSplitter
1 parent 3e8db67 commit 589c7ca

File tree

4 files changed

+57
-6
lines changed

4 files changed

+57
-6
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
99
### Added
1010
- `extras` argument to `SequenceDataset`, `extra_cols` argument to `TransformerDataPreparatorBase`, `session_tower_forward` and `item_tower_forward` methods to `SimilarityModuleBase` ([#287](https://github.com/MobileTeleSystems/RecTools/pull/287))
1111

12+
### Fixed
13+
- [Breaking] Now `LastNSplitter` guarantees taking the last ordered interaction in dataframe in case of identical timestamps ([#288](https://github.com/MobileTeleSystems/RecTools/pull/288))
14+
1215
## [0.14.0] - 16.05.2025
1316

1417
### Added

rectools/model_selection/last_n_split.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023-2024 MTS (Mobile Telesystems)
1+
# Copyright 2023-2025 MTS (Mobile Telesystems)
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
"""LastNSplitter."""
1615

1716
import typing as tp
1817

@@ -103,8 +102,11 @@ def _split_without_filter(
103102
df = interactions.df
104103
idx = pd.RangeIndex(0, len(df))
105104

106-
# last event - rank=1
107-
inv_ranks = df.groupby(Columns.User)[Columns.Datetime].rank(method="first", ascending=False)
105+
# Here we guarantee that last appeared interaction in df will have lowest rank when datetime is not unique
106+
grouped = df.groupby(Columns.User)
107+
time_order = grouped[Columns.Datetime].rank(method="first", ascending=True).astype(int)
108+
n_interactions = grouped[Columns.User].transform("size").astype(int)
109+
inv_ranks = n_interactions - time_order + 1
108110

109111
for i_split in range(self.n_splits)[::-1]:
110112
min_rank = i_split * self.n # excluded

rectools/models/nn/item_net.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ def from_dataset_schema(
230230
@property
231231
def out_dim(self) -> int:
232232
"""Return categorical item embedding output dimension."""
233-
return self.embedding_bag.embedding_dim
233+
return int(self.embedding_bag.embedding_dim)
234234

235235

236236
class IdEmbeddingsItemNet(ItemNetBase):

tests/model_selection/test_last_n_split.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 MTS (Mobile Telesystems)
1+
# Copyright 2023-2025 MTS (Mobile Telesystems)
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -41,6 +41,52 @@ def _shuffle(values: tp.Sequence[int]) -> tp.List[int]:
4141

4242
return _shuffle
4343

44+
@pytest.fixture
45+
def interactions_equal_timestamps(self, shuffle_arr: np.ndarray) -> Interactions:
46+
df = pd.DataFrame(
47+
[
48+
[1, 1, 1, "2021-09-01"], # 0
49+
[1, 2, 1, "2021-09-02"], # 1
50+
[1, 1, 1, "2021-09-03"], # 2
51+
[1, 2, 1, "2021-09-04"], # 3
52+
[1, 3, 1, "2021-09-05"], # 4
53+
[2, 3, 1, "2021-09-05"], # 5
54+
[2, 2, 1, "2021-08-20"], # 6
55+
[2, 2, 1, "2021-09-06"], # 7
56+
[3, 1, 1, "2021-09-05"], # 8
57+
[1, 6, 1, "2021-09-05"], # 9
58+
],
59+
columns=[Columns.User, Columns.Item, Columns.Weight, Columns.Datetime],
60+
).astype({Columns.Datetime: "datetime64[ns]"})
61+
return Interactions(df)
62+
63+
@pytest.mark.parametrize(
64+
"swap_targets,expected_test_ids, target_item",
65+
(
66+
(False, {9, 7, 8}, 6),
67+
(True, {9, 7, 8}, 3),
68+
),
69+
)
70+
def test_correct_last_interactions(
71+
self,
72+
interactions_equal_timestamps: Interactions,
73+
swap_targets: bool,
74+
expected_test_ids: tp.Set[int],
75+
target_item: int,
76+
) -> None:
77+
# Do not using shuffle fixture, otherwise no valid answers
78+
interactions_et = interactions_equal_timestamps
79+
splitter = LastNSplitter(1, 1, False, False, False)
80+
if swap_targets:
81+
df_swap = interactions_equal_timestamps.df
82+
df_swap.iloc[[4, 9]] = df_swap.iloc[[9, 4]]
83+
interactions_et = Interactions(df_swap)
84+
loo_split = list(splitter.split(interactions_et, collect_fold_stats=True))
85+
target_ids = loo_split[0][1]
86+
assert set(target_ids) == expected_test_ids
87+
assert set(loo_split[0][0]) == set(range(len(interactions_et.df))) - expected_test_ids
88+
assert target_item in set(interactions_et.df.iloc[target_ids][Columns.Item])
89+
4490
@pytest.fixture
4591
def interactions(self, shuffle_arr: np.ndarray) -> Interactions:
4692
df = pd.DataFrame(

0 commit comments

Comments
 (0)