
Commit 4a4a874

Feature/features and filter update (#267)
- Added `normalize` argument to `CatalogCoverage`
- Fixed `NDCG` doc
- Fixed keeping extra cols in `Dataset.filter` method. Closes #265
1 parent f7160bb

File tree (7 files changed: +61 -48 lines)

- CHANGELOG.md
- rectools/dataset/dataset.py
- rectools/metrics/catalog.py
- rectools/metrics/ranking.py
- tests/dataset/test_dataset.py
- tests/metrics/test_catalog.py
- tests/metrics/test_scoring.py


CHANGELOG.md

Lines changed: 4 additions & 1 deletion
@@ -9,9 +9,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## Unreleased
 
 ### Added
-- `CatalogCoverage` metric ([#266](https://github.com/MobileTeleSystems/RecTools/pull/266))
+- `CatalogCoverage` metric ([#266](https://github.com/MobileTeleSystems/RecTools/pull/266), [#267](https://github.com/MobileTeleSystems/RecTools/pull/267))
 - `divide_by_achievable` argument to `NDCG` metric ([#266](https://github.com/MobileTeleSystems/RecTools/pull/266))
 
+### Changed
+- Interactions extra columns are not dropped in `Dataset.filter_interactions` method [#267](https://github.com/MobileTeleSystems/RecTools/pull/267)
+
 ## [0.11.0] - 17.02.2025
 
 ### Added

rectools/dataset/dataset.py

Lines changed: 2 additions & 1 deletion
@@ -401,7 +401,8 @@ def filter_interactions(
         # 1x internal -> 2x internal
         user_id_map = IdMap.from_values(interactions_df[Columns.User].values)
         item_id_map = IdMap.from_values(interactions_df[Columns.Item].values)
-        interactions = Interactions.from_raw(interactions_df, user_id_map, item_id_map)
+        # We shouldn't drop extra columns if they are present
+        interactions = Interactions.from_raw(interactions_df, user_id_map, item_id_map, keep_extra_cols=True)
 
         def _handle_features(
             features: tp.Optional[Features], target_id_map: IdMap, dataset_id_map: IdMap
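
For context, a minimal usage sketch of what this fix changes (not part of the commit): an interactions frame with an extra column now keeps that column after `Dataset.filter_interactions`. The `keep_extra_cols` argument of `Dataset.construct` and the exact `filter_interactions` call below are assumptions inferred from this diff and the tests, so treat them as illustrative only.

import numpy as np
import pandas as pd

from rectools import Columns
from rectools.dataset import Dataset

# Interactions with an extra "source" column on top of the standard ones.
interactions = pd.DataFrame(
    {
        Columns.User: [10, 10, 11],
        Columns.Item: [1, 2, 1],
        Columns.Weight: [1, 1, 1],
        Columns.Datetime: ["2021-09-01", "2021-09-02", "2021-09-02"],
        "source": ["app", "web", "app"],
    }
)

# `keep_extra_cols` here is assumed; check the Dataset.construct signature in your version.
dataset = Dataset.construct(interactions, keep_extra_cols=True)

# Keep only the first two interaction rows.
# Before this fix the "source" column was dropped from the filtered dataset;
# with keep_extra_cols=True passed to Interactions.from_raw it is preserved.
filtered = dataset.filter_interactions(np.array([0, 1]))
print(filtered.interactions.df.columns)  # expected to still include "source"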

rectools/metrics/catalog.py

Lines changed: 11 additions & 2 deletions
@@ -16,23 +16,29 @@
 
 import typing as tp
 
+import attr
 import pandas as pd
 
 from rectools import Columns
 
 from .base import Catalog, MetricAtK
 
 
+@attr.s
 class CatalogCoverage(MetricAtK):
     """
-    Share of items in catalog that is present in recommendations for all users.
+    Count (or share) of items from catalog that is present in recommendations for all users.
 
     Parameters
     ----------
     k : int
         Number of items at the top of recommendations list that will be used to calculate metric.
+    normalize: bool, default ``False``
+        Flag, which says whether to normalize metric or not.
     """
 
+    normalize: bool = attr.ib(default=False)
+
     def calc(self, reco: pd.DataFrame, catalog: Catalog) -> float:
         """
         Calculate metric value.
@@ -49,7 +55,10 @@ def calc(self, reco: pd.DataFrame, catalog: Catalog) -> float:
         float
             Value of metric (aggregated for all users).
         """
-        return reco.loc[reco[Columns.Rank] <= self.k, Columns.Item].nunique() / len(catalog)
+        res = reco.loc[reco[Columns.Rank] <= self.k, Columns.Item].nunique()
+        if self.normalize:
+            return res / len(catalog)
+        return res
 
 
 CatalogMetric = CatalogCoverage
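
A quick usage sketch of the new `normalize` flag. The `reco` user column mirrors the test data in this commit; the item and rank values are made up for illustration and chosen so that two distinct items appear within the top 2 ranks.

import numpy as np
import pandas as pd

from rectools import Columns
from rectools.metrics import CatalogCoverage

reco = pd.DataFrame(
    {
        Columns.User: [1, 1, 1, 2, 2, 3, 4],
        Columns.Item: [1, 2, 3, 1, 2, 1, 1],  # illustrative values, not from the test
        Columns.Rank: [1, 2, 3, 1, 2, 1, 1],
    }
)
catalog = np.arange(5)  # catalog of 5 items

# Default behaviour: raw count of distinct recommended items within the top k.
print(CatalogCoverage(k=2).calc(reco, catalog))  # 2 distinct items in top-2

# With normalize=True: share of the catalog that is covered.
print(CatalogCoverage(k=2, normalize=True).calc(reco, catalog))  # 0.4 = 2 / 5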

rectools/metrics/ranking.py

Lines changed: 16 additions & 17 deletions
@@ -314,28 +314,27 @@ class NDCG(_RankingMetric):
     r"""
     Normalized Discounted Cumulative Gain at k (NDCG@k).
 
-    Estimates relevance of recommendations taking in account their order.
+    Estimates relevance of recommendations taking in account their order. `"Discounted Gain"`
+    means that original item relevance is being discounted based on this
+    items rank. The closer is item to the top the, the more gain is achieved.
+    `"Cumulative"` means that all items discounted gains from ``k`` ranks are being summed.
+    `"Normalized"` means that the actual value of DCG is being divided by the `"Ideal DCG"` (IDCG).
+    This is the maximum possible value of `DCG@k`, used as normalization coefficient to ensure that
+    `NDCG@k` values lie in ``[0, 1]``.
 
     .. math::
         NDCG@k=\frac{1}{|U|}\sum_{u \in U}\frac{DCG_u@k}{IDCG_u@k}
 
+        DCG_u@k = \sum_{i=1}^{k} \frac{rel_u(i)}{log(i + 1)}
+
     where
-    - :math:`DCG_u@k` is "Discounted Cumulative Gain" at k for user u.
-    - `"Gain"` stands for relevance of item at position i to user. It equals to ``1`` if this item
-      is relevant, ``0`` otherwise
-    - `"Discounted Gain"` means that original item relevance is being discounted based on this
-      items rank. The closer is item to the top the, the more gain is achieved.
-    - `"Discounted Cumulative Gain"` means that discounted gains are summed together.
-    - :math:`IDCG_u@k` is `"Ideal Discounted Cumulative Gain"` at k for user u. This is maximum
-      possible value of `DCG@k`, used as normalization coefficient to ensure that `NDCG@k`
-      values lie in ``[0, 1]``.
-
-    When `divide_by_achievable` is set to ``False`` (default) `IDCG_u@k` is the same value for all
-    users and is equal to:
-    :math:`IDCG_u@k = \sum_{i=1}^{k} \frac{1}{log(i + 1)}`
-    When `divide_by_achievable` is set to ``True``, the formula for IDCG depends
-    on number of each user relevant items in the test set. The formula is:
-    :math:`IDCG_u@k = \sum_{i=1}^{\min (|R(u)|, k)} \frac{1}{log(i + 1)}`
+    - :math:`IDCG_u@k = \sum_{i=1}^{k} \frac{1}{log(i + 1)}` when `divide_by_achievable` is set
+      to ``False`` (default).
+    - :math:`IDCG_u@k = \sum_{i=1}^{\min (|R(u)|, k)} \frac{1}{log(i + 1)}` when
+      `divide_by_achievable` is set to ``True``.
+    - :math:`rel_u(i)` is `"Gain"`. Here it is an indicator function, it equals to ``1`` if the
+      item at rank ``i`` is relevant to user ``u``, ``0`` otherwise.
+    - :math:`|R_u|` is number of relevant (ground truth) items for user ``u``.
 
     Parameters
     ----------
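
To make the formulas concrete, here is a hand-rolled sketch of DCG/IDCG/NDCG for a single user (not the library implementation). The docstring writes `log(i + 1)` without fixing a base; the sketch assumes base-2 logarithms.

import numpy as np


def dcg_at_k(rels: np.ndarray, k: int) -> float:
    """DCG@k = sum_{i=1..k} rel(i) / log2(i + 1), following the docstring above."""
    rels = np.asarray(rels[:k], dtype=float)
    ranks = np.arange(1, len(rels) + 1)
    return float(np.sum(rels / np.log2(ranks + 1)))


def ndcg_at_k(rels: np.ndarray, k: int, n_relevant: int, divide_by_achievable: bool = False) -> float:
    """NDCG@k for one user; `rels` are 0/1 relevance flags in recommendation order."""
    ideal_len = min(n_relevant, k) if divide_by_achievable else k
    idcg = dcg_at_k(np.ones(ideal_len), ideal_len)  # ideal reco: all top positions relevant
    return dcg_at_k(rels, k) / idcg


# A user with 3 relevant items; the top-4 recommendations hit at ranks 1 and 3.
hits = np.array([1, 0, 1, 0])
print(ndcg_at_k(hits, k=4, n_relevant=3))                             # ~0.59
print(ndcg_at_k(hits, k=4, n_relevant=3, divide_by_achievable=True))  # ~0.70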

tests/dataset/test_dataset.py

Lines changed: 22 additions & 22 deletions
@@ -362,19 +362,19 @@ def dataset_to_filter(self) -> Dataset:
         user_id_map = IdMap.from_values([10, 11, 12, 13, 14])
         df = pd.DataFrame(
             [
-                [0, 0, 1, "2021-09-01"],
-                [4, 2, 1, "2021-09-02"],
-                [2, 1, 1, "2021-09-02"],
-                [2, 2, 1, "2021-09-03"],
-                [3, 2, 1, "2021-09-03"],
-                [3, 3, 1, "2021-09-03"],
-                [3, 4, 1, "2021-09-04"],
-                [1, 2, 1, "2021-09-04"],
-                [3, 1, 1, "2021-09-05"],
-                [4, 2, 1, "2021-09-05"],
-                [3, 3, 1, "2021-09-06"],
+                [0, 0, 1, "2021-09-01", 1],
+                [4, 2, 1, "2021-09-02", 1],
+                [2, 1, 1, "2021-09-02", 1],
+                [2, 2, 1, "2021-09-03", 1],
+                [3, 2, 1, "2021-09-03", 1],
+                [3, 3, 1, "2021-09-03", 1],
+                [3, 4, 1, "2021-09-04", 1],
+                [1, 2, 1, "2021-09-04", 1],
+                [3, 1, 1, "2021-09-05", 1],
+                [4, 2, 1, "2021-09-05", 1],
+                [3, 3, 1, "2021-09-06", 1],
             ],
-            columns=[Columns.User, Columns.Item, Columns.Weight, Columns.Datetime],
+            columns=[Columns.User, Columns.Item, Columns.Weight, Columns.Datetime, "extra"],
         ).astype({Columns.Datetime: "datetime64[ns]"})
         interactions = Interactions(df)
         return Dataset(user_id_map, item_id_map, interactions)
@@ -426,12 +426,12 @@ def test_filter_dataset_interactions_df_rows_without_features(
         )
         expected_interactions_2x_internal_df = pd.DataFrame(
             [
-                [0, 0, 1, "2021-09-01"],
-                [1, 1, 1, "2021-09-02"],
-                [2, 2, 1, "2021-09-02"],
-                [2, 1, 1, "2021-09-03"],
+                [0, 0, 1, "2021-09-01", 1],
+                [1, 1, 1, "2021-09-02", 1],
+                [2, 2, 1, "2021-09-02", 1],
+                [2, 1, 1, "2021-09-03", 1],
             ],
-            columns=[Columns.User, Columns.Item, Columns.Weight, Columns.Datetime],
+            columns=[Columns.User, Columns.Item, Columns.Weight, Columns.Datetime, "extra"],
         ).astype({Columns.Datetime: "datetime64[ns]", Columns.Weight: float})
         np.testing.assert_equal(filtered_dataset.user_id_map.external_ids, expected_external_user_ids)
         np.testing.assert_equal(filtered_dataset.item_id_map.external_ids, expected_external_item_ids)
@@ -464,12 +464,12 @@ def test_filter_dataset_interactions_df_rows_with_features(
        )
         expected_interactions_2x_internal_df = pd.DataFrame(
             [
-                [0, 0, 1, "2021-09-01"],
-                [1, 1, 1, "2021-09-02"],
-                [2, 2, 1, "2021-09-02"],
-                [2, 1, 1, "2021-09-03"],
+                [0, 0, 1, "2021-09-01", 1],
+                [1, 1, 1, "2021-09-02", 1],
+                [2, 2, 1, "2021-09-02", 1],
+                [2, 1, 1, "2021-09-03", 1],
             ],
-            columns=[Columns.User, Columns.Item, Columns.Weight, Columns.Datetime],
+            columns=[Columns.User, Columns.Item, Columns.Weight, Columns.Datetime, "extra"],
         ).astype({Columns.Datetime: "datetime64[ns]", Columns.Weight: float})
         np.testing.assert_equal(filtered_dataset.user_id_map.external_ids, expected_external_user_ids)
         np.testing.assert_equal(filtered_dataset.item_id_map.external_ids, expected_external_item_ids)

tests/metrics/test_catalog.py

Lines changed: 5 additions & 4 deletions
@@ -16,14 +16,14 @@
 
 import numpy as np
 import pandas as pd
+import pytest
 
 from rectools import Columns
 from rectools.metrics import CatalogCoverage
 
 
 class TestCatalogCoverage:
     def setup_method(self) -> None:
-        self.metric = CatalogCoverage(k=2)
         self.reco = pd.DataFrame(
             {
                 Columns.User: [1, 1, 1, 2, 2, 3, 4],
@@ -32,7 +32,8 @@ def setup_method(self) -> None:
             }
         )
 
-    def test_calc(self) -> None:
+    @pytest.mark.parametrize("normalize,expected", ((True, 0.4), (False, 2.0)))
+    def test_calc(self, normalize: bool, expected: float) -> None:
         catalog = np.arange(5)
-        expected = 0.4
-        assert self.metric.calc(self.reco, catalog) == expected
+        metric = CatalogCoverage(k=2, normalize=normalize)
+        assert metric.calc(self.reco, catalog) == expected

tests/metrics/test_scoring.py

Lines changed: 1 addition & 1 deletion
@@ -119,7 +119,7 @@ def test_success(self) -> None:
             "sufficient": SufficientReco(k=2),
             "unrepeated": UnrepeatedReco(k=2),
             "covered_users": CoveredUsers(k=2),
-            "catalog_coverage": CatalogCoverage(k=2),
+            "catalog_coverage": CatalogCoverage(k=2, normalize=True),
         }
         with pytest.warns(UserWarning, match="Custom metrics are not supported"):
             actual = calc_metrics(
