Skip to content

Commit ae77f27

Browse files
authored
Removed internal ids usage outside models (#177)
- Removed `assume_external_ids` parameter in `recommend` and `recommend_to_items` model methods - Added `on_unsupported_targets` in `recommend` and `recommend_to_items` model methods - Added `filter_interactions` method of `Dataset` - Fixed `IntraListDiversity` metric computation in `cross_validate`
1 parent f73d054 commit ae77f27

File tree

8 files changed

+447
-406
lines changed

8 files changed

+447
-406
lines changed

CHANGELOG.md

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1111
### Added
1212
- `Debias` mechanism for classification, ranking and auc metrics. New parameter `is_debiased` to `calc_from_confusion_df`, `calc_per_user_from_confusion_df` methods of classification metrics, `calc_from_fitted`, `calc_per_user_from_fitted` methods of auc and ranking (`MAP`) metrics, `calc_from_merged`, `calc_per_user_from_merged` methods of ranking (`NDCG`, `MRR`) metrics. ([#152](https://github.com/MobileTeleSystems/RecTools/pull/152))
1313
- `nbformat >= 4.2.0` dependency to `[visuals]` extra ([#169](https://github.com/MobileTeleSystems/RecTools/pull/169))
14+
- `filter_interactions` method of `Dataset` ([#177](https://github.com/MobileTeleSystems/RecTools/pull/177))
15+
- `on_unsupported_targets` parameter to `recommend` and `recommend_to_items` model methods ([#177](https://github.com/MobileTeleSystems/RecTools/pull/177))
1416

1517
### Fixed
1618
- `display()` method in `MetricsApp` ([#169](https://github.com/MobileTeleSystems/RecTools/pull/169))
17-
18-
### Fixed
19+
- `IntraListDiversity` metric computation in `cross_validate` ([#177](https://github.com/MobileTeleSystems/RecTools/pull/177))
1920
- Allow warp-kos loss for LightFMWrapperModel ([#175](https://github.com/MobileTeleSystems/RecTools/pull/175))
2021

22+
### Removed
23+
- [Breaking] `assume_external_ids` parameter in `recommend` and `recommend_to_items` model methods ([#177](https://github.com/MobileTeleSystems/RecTools/pull/177))
24+
2125
## [0.7.0] - 29.07.2024
2226

2327
### Added

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ RecTools is on PyPI, so you can use `pip` to install it.
7979
```
8080
pip install rectools
8181
```
82-
The default version doesn't contain all the dependencies, because some of them are needed only for specific models. Available user extensions are the following:
82+
The default version doesn't contain all the dependencies, because some of them are needed only for specific functionality. Available user extensions are the following:
8383

8484
- `lightfm`: adds wrapper for LightFM model,
8585
- `torch`: adds models based on neural nets,

rectools/dataset/dataset.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import typing as tp
1818

1919
import attr
20+
import numpy as np
2021
import pandas as pd
2122
from scipy import sparse
2223

@@ -245,3 +246,66 @@ def get_raw_interactions(self, include_weight: bool = True, include_datetime: bo
245246
pd.DataFrame
246247
"""
247248
return self.interactions.to_external(self.user_id_map, self.item_id_map, include_weight, include_datetime)
249+
250+
def filter_interactions(
    self,
    row_indexes_to_keep: np.ndarray,
    keep_external_ids: bool = True,
    keep_features_for_removed_entities: bool = True,
) -> "Dataset":
    """
    Build a new dataset restricted to the given row indexes of the original
    interactions dataframe.

    Both users and items receive a fresh id mapping in the resulting dataset.

    Parameters
    ----------
    row_indexes_to_keep : np.ndarray
        Row indexes (positions) of the original interactions df to keep.
    keep_external_ids : bool, default `True`
        If True, map external ids -> 2x internal ids (default).
        If False, map 1x internal ids -> 2x internal ids.
    keep_features_for_removed_entities : bool, default `True`
        If True, retain features even for users/items that are no longer hot.

    Returns
    -------
    Dataset
        Dataset holding only the selected interactions, with remapped ids
        and correspondingly re-indexed features.
    """
    kept_df = self.interactions.df.iloc[row_indexes_to_keep]

    # Build fresh id maps: 1x internal -> 2x internal
    new_user_id_map = IdMap.from_values(kept_df[Columns.User].values)
    new_item_id_map = IdMap.from_values(kept_df[Columns.Item].values)
    new_interactions = Interactions.from_raw(kept_df, new_user_id_map, new_item_id_map)

    def _reindex_features(
        features: tp.Optional[Features], new_id_map: IdMap, old_id_map: IdMap
    ) -> tp.Tuple[tp.Optional[Features], IdMap]:
        # Re-index features to the new id map; optionally extend the map so
        # features of removed (no longer hot) entities survive filtering.
        # NOTE(review): `old_id_map` is accepted but not read here — kept for
        # signature stability with the call sites below.
        if features is None:
            return None, new_id_map

        if keep_features_for_removed_entities:
            every_feature_id = np.arange(len(features))
            new_id_map = new_id_map.add_ids(every_feature_id, raise_if_already_present=False)

        reindexed = features.take(new_id_map.get_external_sorted_by_internal())
        return reindexed, new_id_map

    new_user_features, new_user_id_map = _reindex_features(
        self.user_features, new_user_id_map, self.user_id_map
    )
    new_item_features, new_item_id_map = _reindex_features(
        self.item_features, new_item_id_map, self.item_id_map
    )

    if keep_external_ids:
        # Re-anchor maps to external ids: external -> 2x internal
        new_user_id_map = IdMap(self.user_id_map.convert_to_external(new_user_id_map.external_ids))
        new_item_id_map = IdMap(self.item_id_map.convert_to_external(new_item_id_map.external_ids))

    return Dataset(
        user_id_map=new_user_id_map,
        item_id_map=new_item_id_map,
        interactions=new_interactions,
        user_features=new_user_features,
        item_features=new_item_features,
    )

rectools/model_selection/cross_validate.py

Lines changed: 29 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -14,58 +14,16 @@
1414

1515
import typing as tp
1616

17-
import numpy as np
18-
import pandas as pd
19-
2017
from rectools.columns import Columns
21-
from rectools.dataset import Dataset, Features, IdMap, Interactions
18+
from rectools.dataset import Dataset
2219
from rectools.metrics import calc_metrics
2320
from rectools.metrics.base import MetricAtK
24-
from rectools.models.base import ModelBase
21+
from rectools.models.base import ErrorBehaviour, ModelBase
2522
from rectools.types import ExternalIds
2623

2724
from .splitter import Splitter
2825

2926

30-
def _gen_2x_internal_ids_dataset(
31-
interactions_internal_df: pd.DataFrame,
32-
user_features: tp.Optional[Features],
33-
item_features: tp.Optional[Features],
34-
prefer_warm_inference_over_cold: bool,
35-
) -> Dataset:
36-
"""
37-
Make new dataset based on given interactions and features from base dataset.
38-
Assume that interactions dataframe contains internal ids.
39-
Returned dataset contains 2nd level of internal ids.
40-
"""
41-
user_id_map = IdMap.from_values(interactions_internal_df[Columns.User].values) # 1x internal -> 2x internal
42-
item_id_map = IdMap.from_values(interactions_internal_df[Columns.Item].values) # 1x internal -> 2x internal
43-
interactions_train = Interactions.from_raw(interactions_internal_df, user_id_map, item_id_map) # 2x internal
44-
45-
def _handle_features(features: tp.Optional[Features], id_map: IdMap) -> tp.Tuple[tp.Optional[Features], IdMap]:
46-
if features is None:
47-
return None, id_map
48-
49-
if prefer_warm_inference_over_cold:
50-
all_features_ids = np.arange(len(features)) # 1x internal
51-
id_map = id_map.add_ids(all_features_ids, raise_if_already_present=False)
52-
53-
features = features.take(id_map.get_external_sorted_by_internal()) # 2x internal
54-
return features, id_map
55-
56-
user_features_new, user_id_map = _handle_features(user_features, user_id_map)
57-
item_features_new, item_id_map = _handle_features(item_features, item_id_map)
58-
59-
dataset = Dataset(
60-
user_id_map=user_id_map,
61-
item_id_map=item_id_map,
62-
interactions=interactions_train,
63-
user_features=user_features_new,
64-
item_features=item_features_new,
65-
)
66-
return dataset
67-
68-
6927
def cross_validate( # pylint: disable=too-many-locals
7028
dataset: Dataset,
7129
splitter: Splitter,
@@ -77,6 +35,7 @@ def cross_validate( # pylint: disable=too-many-locals
7735
prefer_warm_inference_over_cold: bool = True,
7836
ref_models: tp.Optional[tp.List[str]] = None,
7937
validate_ref_models: bool = False,
38+
on_unsupported_targets: ErrorBehaviour = "warn",
8039
) -> tp.Dict[str, tp.Any]:
8140
"""
8241
Run cross validation on multiple models with multiple metrics.
@@ -113,6 +72,15 @@ def cross_validate( # pylint: disable=too-many-locals
11372
validate_ref_models : bool, default False
11473
If True include models specified in `ref_models` to all metrics calculations
11574
and receive their metrics from cross-validation.
75+
on_unsupported_targets : Literal["raise", "warn", "ignore"], default "warn"
76+
How to handle warm/cold target users when model doesn't support warm/cold inference.
77+
Specify "warn" to filter with warning (default in `cross_validate`).
78+
Specify "ignore" to filter unsupported targets without a warning.
79+
It is highly recommended to pass `CoveredUsers` DQ metric to catch all models with
80+
insufficient recommendations for each fold.
81+
Specify "raise" to raise ValueError in case unsupported targets are passed. In cross-validation
82+
this may cause unexpected errors for some of the complicated models.
83+
11684
11785
Returns
11886
-------
@@ -132,34 +100,26 @@ def cross_validate( # pylint: disable=too-many-locals
132100
]
133101
}
134102
"""
135-
interactions = dataset.interactions
136-
137-
split_iterator = splitter.split(interactions, collect_fold_stats=True)
103+
split_iterator = splitter.split(dataset.interactions, collect_fold_stats=True)
138104

139105
split_infos = []
140106
metrics_all = []
141107

142108
for train_ids, test_ids, split_info in split_iterator:
143109
split_infos.append(split_info)
144110

145-
# ### Prepare split data
146-
interactions_df_train = interactions.df.iloc[train_ids] # 1x internal
147-
# We need to avoid fitting models on sparse matrices with all zero rows/columns =>
148-
# => we need to create a fold dataset which contains only hot users and items for current training
149-
fold_dataset = _gen_2x_internal_ids_dataset(
150-
interactions_df_train, dataset.user_features, dataset.item_features, prefer_warm_inference_over_cold
111+
fold_dataset = dataset.filter_interactions(
112+
row_indexes_to_keep=train_ids,
113+
keep_external_ids=True,
114+
keep_features_for_removed_entities=prefer_warm_inference_over_cold,
151115
)
116+
interactions_df_test = dataset.interactions.df.loc[test_ids]
117+
interactions_df_test[Columns.User] = dataset.user_id_map.convert_to_external(interactions_df_test[Columns.User])
118+
interactions_df_test[Columns.Item] = dataset.item_id_map.convert_to_external(interactions_df_test[Columns.Item])
152119

153-
interactions_df_test = interactions.df.iloc[test_ids] # 1x internal
154-
test_users = interactions_df_test[Columns.User].unique() # 1x internal
155-
catalog = interactions_df_train[Columns.Item].unique() # 1x internal
156-
157-
if items_to_recommend is not None:
158-
item_ids_to_recommend = dataset.item_id_map.convert_to_internal(
159-
items_to_recommend, strict=False
160-
) # 1x internal
161-
else:
162-
item_ids_to_recommend = None
120+
test_users = interactions_df_test[Columns.User].unique()
121+
prev_interactions = fold_dataset.get_raw_interactions()
122+
catalog = prev_interactions[Columns.Item].unique()
163123

164124
# ### Train ref models if any
165125
ref_reco = {}
@@ -171,7 +131,8 @@ def cross_validate( # pylint: disable=too-many-locals
171131
dataset=fold_dataset,
172132
k=k,
173133
filter_viewed=filter_viewed,
174-
items_to_recommend=item_ids_to_recommend,
134+
items_to_recommend=items_to_recommend,
135+
on_unsupported_targets=on_unsupported_targets,
175136
)
176137

177138
# ### Generate recommendations and calc metrics
@@ -183,19 +144,20 @@ def cross_validate( # pylint: disable=too-many-locals
183144
reco = ref_reco[model_name]
184145
else:
185146
model.fit(fold_dataset)
186-
reco = model.recommend( # 1x internal
147+
reco = model.recommend(
187148
users=test_users,
188149
dataset=fold_dataset,
189150
k=k,
190151
filter_viewed=filter_viewed,
191-
items_to_recommend=item_ids_to_recommend,
152+
items_to_recommend=items_to_recommend,
153+
on_unsupported_targets=on_unsupported_targets,
192154
)
193155

194156
metric_values = calc_metrics(
195157
metrics,
196158
reco=reco,
197159
interactions=interactions_df_test,
198-
prev_interactions=interactions_df_train,
160+
prev_interactions=prev_interactions,
199161
catalog=catalog,
200162
ref_reco=ref_reco,
201163
)

0 commit comments

Comments
 (0)