|
| 1 | +# Copyright 2022-2025 MTS (Mobile Telesystems) |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +import typing as tp |
| 16 | + |
| 17 | +import numpy as np |
| 18 | +import pandas as pd |
| 19 | +import pytest |
| 20 | + |
| 21 | +from rectools import Columns |
| 22 | +from rectools.dataset import Interactions |
| 23 | +from rectools.models.nn.transformers.utils import leave_one_out_mask |
| 24 | + |
| 25 | + |
class TestLeaveOneOutMask:
    """Tests for ``leave_one_out_mask`` on a small hand-written interaction log."""

    def setup_method(self) -> None:
        """Reseed NumPy's global random generator before each test method."""
        # NOTE(review): presumably needed because ``leave_one_out_mask`` samples users
        # when ``val_users`` is an int -- confirm against its implementation.
        np.random.seed(32)

    @pytest.fixture
    def interactions(self) -> Interactions:
        """Return ten interactions of users 1-3 wrapped in ``Interactions``.

        The trailing comment on every row is its position in the frame; the
        parametrized cases below refer to rows by these positions.
        """
        df = pd.DataFrame(
            [
                [1, 1, 1, "2021-09-01"],  # 0
                [1, 2, 1, "2021-09-02"],  # 1
                [1, 1, 1, "2021-09-03"],  # 2
                [1, 2, 1, "2021-09-04"],  # 3
                [1, 3, 1, "2021-09-05"],  # 4
                [2, 3, 1, "2021-09-06"],  # 5
                [2, 2, 1, "2021-08-20"],  # 6
                [2, 2, 1, "2021-09-06"],  # 7
                [3, 1, 1, "2021-09-05"],  # 8
                [1, 6, 1, "2021-09-05"],  # 9
            ],
            columns=[Columns.User, Columns.Item, Columns.Weight, Columns.Datetime],
        ).astype({Columns.Datetime: "datetime64[ns]"})
        return Interactions(df)

    @pytest.mark.parametrize(
        "swap_interactions,expected_val_index, expected_val_item, val_users",
        (
            # Rows 4 and 9 (user 1) and rows 5 and 7 (user 2) share a timestamp.
            # Swapping them changes which item sits at the later position while
            # the expected index stays the same.
            ([9, 9], [7, 8, 9], 6, None),
            ([9, 9], [7, 8, 9], 6, 3),
            ([9, 9], [8, 9], 6, 2),
            ([4, 9], [7, 8, 9], 3, None),
            ([4, 9], [7, 8, 9], 3, 3),
            ([4, 9], [8, 9], 3, 2),
            ([7, 7], [7, 8], 2, [2, 3]),
            ([5, 7], [7, 8], 3, [2, 3]),
            ([8, 8], [8], 1, [3]),
        ),
    )
    def test_correct_last_interactions(
        self,
        interactions: Interactions,
        swap_interactions: tp.List[int],
        expected_val_index: tp.List[int],
        expected_val_item: int,
        val_users: tp.Optional[tp.Union[int, tp.List[int]]],
    ) -> None:
        """Check which rows ``leave_one_out_mask`` selects after swapping two rows.

        Parameters
        ----------
        interactions : Interactions
            Fixture with the ten-row interaction log.
        swap_interactions : list of int
            Two row positions whose contents are exchanged before the mask is
            computed. A pair of equal positions leaves the frame unchanged.
        expected_val_index : list of int
            Index labels of the rows the mask is expected to select, in order.
        expected_val_item : int
            Item id expected in the selected row at ``max(swap_interactions)``.
        val_users : int, list of int or None
            Forwarded unchanged as the second argument of ``leave_one_out_mask``.
        """
        interactions_df = interactions.df

        # Exchange the two rows positionally: position a receives row b and vice versa.
        swap_revert = swap_interactions[::-1]
        interactions_df.iloc[swap_interactions] = interactions_df.iloc[swap_revert]

        val_mask = leave_one_out_mask(interactions_df, val_users)
        val_interactions = interactions_df[val_mask]

        # The larger swapped position is the one every case expects to be selected.
        last_index = max(swap_interactions)

        assert list(val_interactions.index) == expected_val_index
        assert val_interactions.loc[last_index, [Columns.Item]].values[0] == expected_val_item
0 commit comments