|
1 | | -# Copyright 2023 MTS (Mobile Telesystems) |
| 1 | +# Copyright 2023-2025 MTS (Mobile Telesystems) |
2 | 2 | # |
3 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); |
4 | 4 | # you may not use this file except in compliance with the License. |
@@ -41,6 +41,52 @@ def _shuffle(values: tp.Sequence[int]) -> tp.List[int]: |
41 | 41 |
|
42 | 42 | return _shuffle |
43 | 43 |
|
@pytest.fixture
def interactions_equal_timestamps(self, shuffle_arr: np.ndarray) -> Interactions:
    """Build interactions in which several rows share the same timestamp.

    Rows 4, 5, 8 and 9 all carry the date ``2021-09-05``; rows 4 and 9
    additionally belong to the same user (1) but refer to different items
    (3 and 6). The trailing ``# n`` comments give each row's position.

    ``shuffle_arr`` is requested but not used in the body: rows are kept in
    the order they are written below.
    """
    column_names = [Columns.User, Columns.Item, Columns.Weight, Columns.Datetime]
    # Each record is [user, item, weight, datetime].
    records = [
        [1, 1, 1, "2021-09-01"],  # 0
        [1, 2, 1, "2021-09-02"],  # 1
        [1, 1, 1, "2021-09-03"],  # 2
        [1, 2, 1, "2021-09-04"],  # 3
        [1, 3, 1, "2021-09-05"],  # 4
        [2, 3, 1, "2021-09-05"],  # 5
        [2, 2, 1, "2021-08-20"],  # 6
        [2, 2, 1, "2021-09-06"],  # 7
        [3, 1, 1, "2021-09-05"],  # 8
        [1, 6, 1, "2021-09-05"],  # 9
    ]
    frame = pd.DataFrame(records, columns=column_names)
    # Date strings are converted to real timestamps before wrapping.
    frame = frame.astype({Columns.Datetime: "datetime64[ns]"})
    return Interactions(frame)
| 62 | + |
@pytest.mark.parametrize(
    "swap_targets,expected_test_ids,target_item",
    (
        (False, {9, 7, 8}, 6),
        (True, {9, 7, 8}, 3),
    ),
)
def test_correct_last_interactions(
    self,
    interactions_equal_timestamps: Interactions,
    swap_targets: bool,
    expected_test_ids: tp.Set[int],
    target_item: int,
) -> None:
    """Check which row goes to the test fold when timestamps are tied.

    In the fixture, rows 4 and 9 belong to the same user and share one
    timestamp but hold different items (3 and 6). The expected test
    positions are ``{9, 7, 8}`` in both parametrizations; when
    ``swap_targets`` is true the contents of rows 4 and 9 are exchanged,
    so the item found at the selected position changes from 6 to 3.

    Parameters
    ----------
    interactions_equal_timestamps : Interactions
        Fixture with tied timestamps.
    swap_targets : bool
        Whether to exchange rows 4 and 9 before splitting.
    expected_test_ids : set of int
        Row positions expected in the test part of the first fold.
    target_item : int
        Item id that must be present among the selected test rows.
    """
    # The shuffle fixture is deliberately not used here: the expected ids
    # are row positions, so reordering rows would leave no valid answers.
    interactions_et = interactions_equal_timestamps
    splitter = LastNSplitter(1, 1, False, False, False)
    if swap_targets:
        # Swap on a copy so the fixture-owned frame is not mutated in place
        # through an alias.
        df_swap = interactions_equal_timestamps.df.copy()
        df_swap.iloc[[4, 9]] = df_swap.iloc[[9, 4]]
        interactions_et = Interactions(df_swap)
    loo_split = list(splitter.split(interactions_et, collect_fold_stats=True))
    # Only the first fold is inspected; elements 0 and 1 are train and test ids.
    first_fold = loo_split[0]
    train_ids, target_ids = first_fold[0], first_fold[1]
    assert set(target_ids) == expected_test_ids
    # Everything that is not in the test part must end up in train.
    assert set(train_ids) == set(range(len(interactions_et.df))) - expected_test_ids
    assert target_item in set(interactions_et.df.iloc[target_ids][Columns.Item])
| 89 | + |
44 | 90 | @pytest.fixture |
45 | 91 | def interactions(self, shuffle_arr: np.ndarray) -> Interactions: |
46 | 92 | df = pd.DataFrame( |
|
0 commit comments