Skip to content

Commit 374b4ae

Browse files
authored
Fix/pop in cat refit (#163)
- Fixed `PopularInCategoryModel` refit behaviour and `cross_validate` compatibility
- Fixed `PopularInCategoryModel` behaviour with empty category interactions
- Added tests

Closes #162
1 parent 0699727 commit 374b4ae

File tree

5 files changed

+44
-13
lines changed

5 files changed

+44
-13
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2020
### Fixed
2121
- Used the latest version of `lightfm`, which can be installed using `poetry>=1.5.0` ([#141](https://github.com/MobileTeleSystems/RecTools/pull/141))
2222
- Added a restriction on the `pytorch` version for MacOSX + x86_64 that allows it to be installed on such platforms ([#142](https://github.com/MobileTeleSystems/RecTools/pull/142))
23+
- `PopularInCategoryModel` behaviour when fitted multiple times, `cross_validate` compatibility, and behaviour with empty category interactions ([#163](https://github.com/MobileTeleSystems/RecTools/pull/163))
2324

2425

2526
## [0.6.0] - 13.05.2024

README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,9 @@
1616
<a href="https://rectools.readthedocs.io/en/stable/">Documentation</a> |
1717
<a href="https://github.com/MobileTeleSystems/RecTools/tree/main/examples">Examples</a> |
1818
<a href="https://github.com/MobileTeleSystems/RecTools/tree/main/examples/tutorials">Tutorials</a> |
19-
<a href="https://github.com/MobileTeleSystems/RecTools/blob/main/CONTRIBUTING.rst">Contribution Guide</a> |
20-
<a href="https://github.com/MobileTeleSystems/RecTools/releases">Release Notes</a>
19+
<a href="https://github.com/MobileTeleSystems/RecTools/blob/main/CONTRIBUTING.rst">Contributing</a> |
20+
<a href="https://github.com/MobileTeleSystems/RecTools/releases">Releases</a> |
21+
<a href="https://github.com/orgs/MobileTeleSystems/projects/1">Developers Board</a>
2122
</p>
2223

2324
RecTools is an easy-to-use Python library which makes the process of building recommendation systems easier,

rectools/models/popular_in_category.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -160,15 +160,19 @@ def _check_category_feature(self, dataset: Dataset) -> None:
160160

161161
def _calc_category_scores(self, dataset: Dataset, interactions: pd.DataFrame) -> None:
162162
scores_dict = {}
163+
empty_columns = []
163164
for column_num in self.category_columns:
164165
item_idx = dataset.item_features.values.getcol(column_num).nonzero()[0] # type: ignore
165-
self.category_interactions[column_num] = interactions[interactions[Columns.Item].isin(item_idx)].copy()
166+
category_interactions = interactions[interactions[Columns.Item].isin(item_idx)]
166167
# Category interactions might be empty
167-
if self.category_interactions[column_num].shape[0] == 0:
168-
self.category_columns.remove(column_num)
168+
if category_interactions.shape[0] == 0:
169+
empty_columns.append(column_num)
169170
else:
171+
self.category_interactions[column_num] = category_interactions.copy()
170172
col, func = self._get_groupby_col_and_agg_func(self.popularity)
171173
scores_dict[column_num] = self.category_interactions[column_num][col].apply(func)
174+
175+
self.category_columns = [col for col in self.category_columns if col not in empty_columns]
172176
self.category_scores = pd.Series(scores_dict).sort_values(ascending=False)
173177

174178
def _define_categories_for_analysis(self) -> None:
@@ -177,7 +181,7 @@ def _define_categories_for_analysis(self) -> None:
177181
self.n_effective_categories = self.n_categories
178182
relevant_categories = self.category_scores.head(self.n_categories).index
179183
self.category_scores = self.category_scores.loc[relevant_categories]
180-
self.category_columns = relevant_categories
184+
self.category_columns = relevant_categories.to_list()
181185
else:
182186
self.n_effective_categories = len(self.category_columns)
183187
warnings.warn(
@@ -188,6 +192,13 @@ def _define_categories_for_analysis(self) -> None:
188192
self.n_effective_categories = len(self.category_columns)
189193

190194
def _fit(self, dataset: Dataset) -> None: # type: ignore
195+
196+
self.category_columns = []
197+
self.category_interactions = {}
198+
self.models = {}
199+
self.category_scores = pd.Series()
200+
self.n_effective_categories = 0
201+
191202
self._check_category_feature(dataset)
192203
interactions = self._filter_interactions(dataset.interactions.df)
193204
self._calc_category_scores(dataset, interactions)

tests/model_selection/test_cross_validate.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from rectools.metrics.base import MetricAtK
2929
from rectools.model_selection import LastNSplitter, cross_validate
3030
from rectools.model_selection.cross_validate import _gen_2x_internal_ids_dataset
31-
from rectools.models import ImplicitALSWrapperModel, PopularModel, RandomModel
31+
from rectools.models import ImplicitALSWrapperModel, PopularInCategoryModel, PopularModel, RandomModel
3232
from rectools.models.base import ModelBase
3333
from tests.testing_utils import assert_sparse_matrix_equal
3434

@@ -146,6 +146,7 @@ def setup_method(self) -> None:
146146
[14, "f2", 1],
147147
[11, "f1", "y"],
148148
[11, "f2", 2],
149+
[12, "f1", "y"],
149150
],
150151
columns=["id", "feature", "value"],
151152
)
@@ -247,6 +248,7 @@ def test_happy_path_with_features(self, prefer_warm_inference_over_cold: bool) -
247248

248249
models: tp.Dict[str, ModelBase] = {
249250
"als": ImplicitALSWrapperModel(AlternatingLeastSquares(factors=2, iterations=2, random_state=42)),
251+
"pop_in_cat": PopularInCategoryModel(category_feature="f1", n_categories=2),
250252
}
251253

252254
actual = cross_validate(
@@ -282,7 +284,9 @@ def test_happy_path_with_features(self, prefer_warm_inference_over_cold: bool) -
282284
],
283285
"metrics": [
284286
{"model": "als", "i_split": 0, "precision@2": 0.5, "recall@1": 0.0},
285-
{"model": "als", "i_split": 1, "precision@2": 0.375, "recall@1": 0.25},
287+
{"model": "pop_in_cat", "i_split": 0, "precision@2": 0.5, "recall@1": 0.5},
288+
{"model": "als", "i_split": 1, "precision@2": 0.375, "recall@1": 0.0},
289+
{"model": "pop_in_cat", "i_split": 1, "precision@2": 0.375, "recall@1": 0.25},
286290
],
287291
}
288292

tests/models/test_popular_in_category.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -422,11 +422,25 @@ def test_i2i(
422422
actual,
423423
)
424424

425-
def test_second_fit_refits_model(self, dataset: Dataset) -> None:
425+
@pytest.mark.parametrize("popularity", ("mean_weight", "n_users", "n_interactions"))
426+
@pytest.mark.parametrize("category_feature", ("f1", "f2"))
427+
@pytest.mark.parametrize("mixing_strategy", ("group", "rotate"))
428+
@pytest.mark.parametrize("ratio_strategy", ("equal", "proportional"))
429+
@pytest.mark.parametrize("n_categories", (2, None))
430+
def test_second_fit_refits_model(
431+
self,
432+
dataset: Dataset,
433+
popularity: str,
434+
category_feature: str,
435+
mixing_strategy: str,
436+
ratio_strategy: str,
437+
n_categories: tp.Optional[int],
438+
) -> None:
426439
model = PopularInCategoryModel(
427-
category_feature="f2",
428-
popularity="mean_weight",
429-
mixing_strategy="group",
430-
ratio_strategy="proportional",
440+
category_feature=category_feature,
441+
popularity=popularity,
442+
mixing_strategy=mixing_strategy,
443+
ratio_strategy=ratio_strategy,
444+
n_categories=n_categories,
431445
)
432446
assert_second_fit_refits_model(model, dataset)

0 commit comments

Comments
 (0)