
Commit f7160bb

Feature/idcg and coverage (#266)

- Added `CatalogCoverage`
- Added `divide_by_achievable` argument for `NDCG`

1 parent e8728b3 commit f7160bb

8 files changed: +213 −33 lines changed

CHANGELOG.md

Lines changed: 7 additions & 0 deletions
@@ -5,6 +5,13 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+
+## Unreleased
+
+### Added
+- `CatalogCoverage` metric ([#266](https://github.com/MobileTeleSystems/RecTools/pull/266))
+- `divide_by_achievable` argument to `NDCG` metric ([#266](https://github.com/MobileTeleSystems/RecTools/pull/266))
+
 ## [0.11.0] - 17.02.2025
 
 ### Added

rectools/metrics/__init__.py

Lines changed: 4 additions & 1 deletion
@@ -1,4 +1,4 @@
-# Copyright 2022-2024 MTS (Mobile Telesystems)
+# Copyright 2022-2025 MTS (Mobile Telesystems)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -40,6 +40,7 @@
 `metrics.SufficientReco`
 `metrics.UnrepeatedReco`
 `metrics.CoveredUsers`
+`metrics.CatalogCoverage`
 
 Tools
 -----
@@ -52,6 +53,7 @@
 """
 
 from .auc import PAP, PartialAUC
+from .catalog import CatalogCoverage
 from .classification import MCC, Accuracy, F1Beta, HitRate, Precision, Recall
 from .debias import DebiasConfig, debias_interactions
 from .distances import (
@@ -80,6 +82,7 @@
     "PartialAUC",
     "PAP",
     "MRR",
+    "CatalogCoverage",
     "MeanInvUserFreq",
     "IntraListDiversity",
     "AvgRecPopularity",

rectools/metrics/catalog.py

Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
+# Copyright 2025 MTS (Mobile Telesystems)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Catalog statistics metrics for recommendations."""
+
+import typing as tp
+
+import pandas as pd
+
+from rectools import Columns
+
+from .base import Catalog, MetricAtK
+
+
+class CatalogCoverage(MetricAtK):
+    """
+    Share of catalog items that are present in recommendations across all users.
+
+    Parameters
+    ----------
+    k : int
+        Number of items at the top of recommendations list that will be used to calculate metric.
+    """
+
+    def calc(self, reco: pd.DataFrame, catalog: Catalog) -> float:
+        """
+        Calculate metric value.
+
+        Parameters
+        ----------
+        reco : pd.DataFrame
+            Recommendations table with columns `Columns.User`, `Columns.Item`, `Columns.Rank`.
+        catalog : collection
+            Collection of unique item ids that could be used for recommendations.
+
+        Returns
+        -------
+        float
+            Value of metric (aggregated for all users).
+        """
+        return reco.loc[reco[Columns.Rank] <= self.k, Columns.Item].nunique() / len(catalog)
+
+
+CatalogMetric = CatalogCoverage
+
+
+def calc_catalog_metrics(
+    metrics: tp.Dict[str, CatalogMetric],
+    reco: pd.DataFrame,
+    catalog: Catalog,
+) -> tp.Dict[str, float]:
+    """
+    Calculate metrics of catalog statistics for recommendations.
+
+    Warning: It is not recommended to use this function directly.
+    Use `calc_metrics` instead.
+
+    Parameters
+    ----------
+    metrics : dict(str -> CatalogMetric)
+        Dict of metric objects to calculate,
+        where key is a metric name and value is a metric object.
+    reco : pd.DataFrame
+        Recommendations table with columns `Columns.User`, `Columns.Item`, `Columns.Rank`.
+    catalog : collection
+        Collection of unique item ids that could be used for recommendations.
+
+    Returns
+    -------
+    dict(str -> float)
+        Dictionary where keys are the same as keys in `metrics`
+        and values are metric calculation results.
+    """
+    return {metric_name: metric.calc(reco, catalog) for metric_name, metric in metrics.items()}
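
A quick usage sketch of the new class (not part of the commit; the data is illustrative, while `CatalogCoverage(k=...).calc(reco, catalog)` follows the signature defined above):

```python
import numpy as np
import pandas as pd

from rectools import Columns
from rectools.metrics import CatalogCoverage

# Illustrative recommendations: within the top k=2, only items {1, 2} appear.
reco = pd.DataFrame(
    {
        Columns.User: [1, 1, 2, 2],
        Columns.Item: [1, 2, 1, 3],
        Columns.Rank: [1, 2, 1, 3],
    }
)
catalog = np.arange(5)  # 5 items could have been recommended

print(CatalogCoverage(k=2).calc(reco, catalog))  # 2 unique items / 5 -> 0.4
```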

rectools/metrics/ranking.py

Lines changed: 54 additions & 22 deletions
@@ -1,4 +1,4 @@
-# Copyright 2022-2024 MTS (Mobile Telesystems)
+# Copyright 2022-2025 MTS (Mobile Telesystems)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -317,27 +317,37 @@ class NDCG(_RankingMetric):
     Estimates relevance of recommendations taking into account their order.
 
     .. math::
-        NDCG@k = DCG@k / IDCG@k
-    where :math:`DCG@k = \sum_{i=1}^{k+1} rel(i) / \log(i+1)` -
-    Discounted Cumulative Gain at k, main part of `NDCG@k`.
+        NDCG@k = \frac{1}{|U|} \sum_{u \in U} \frac{DCG_u@k}{IDCG_u@k}
 
-    The closer it is to the top the more weight it assigns to relevant items.
-    Here:
-        - `rel(i)` is an indicator function, it equals to ``1``
-          if an item at rank `i` is relevant, ``0`` otherwise;
-        - `log` - logarithm at any given base, usually ``2``.
-
-    and :math:`IDCG@k = \sum_{i=1}^{k+1} (1 / \log(i + 1))` -
-    `Ideal DCG@k`, maximum possible value of `DCG@k`, used as
-    normalization coefficient to ensure that `NDCG@k` values
-    lie in ``[0, 1]``.
+    where
+
+    - :math:`DCG_u@k` is the "Discounted Cumulative Gain" at k for user u;
+    - "Gain" stands for the relevance of the item at position i to the user: it equals ``1``
+      if the item is relevant, ``0`` otherwise;
+    - "Discounted Gain" means that the original item relevance is discounted based on the
+      item's rank: the closer the item is to the top, the more gain is achieved;
+    - "Discounted Cumulative Gain" means that the discounted gains are summed together;
+    - :math:`IDCG_u@k` is the "Ideal Discounted Cumulative Gain" at k for user u: the maximum
+      possible value of `DCG@k`, used as a normalization coefficient to ensure that `NDCG@k`
+      values lie in ``[0, 1]``.
+
+    When `divide_by_achievable` is set to ``False`` (default), `IDCG_u@k` is the same value for
+    all users and is equal to
+    :math:`IDCG_u@k = \sum_{i=1}^{k} \frac{1}{\log(i + 1)}`.
+    When `divide_by_achievable` is set to ``True``, IDCG depends on the number of relevant
+    items each user has in the test set:
+    :math:`IDCG_u@k = \sum_{i=1}^{\min(|R(u)|, k)} \frac{1}{\log(i + 1)}`.
 
     Parameters
     ----------
     k : int
         Number of items at the top of recommendations list that will be used to calculate metric.
     log_base : int, default ``2``
         Base of logarithm used to weight relevant items.
+    divide_by_achievable : bool, default ``False``
+        When set to ``False`` (default), IDCG is calculated as one value for all users and
+        equals the maximum gain achievable when all ``k`` positions are relevant.
+        When set to ``True``, IDCG is calculated for each user individually, considering
+        the maximum possible number of the user's test items in the top ``k`` positions.
     debias_config : DebiasConfig, optional, default None
         Config with debias method parameters (iqr_coef, random_state).
@@ -368,6 +378,7 @@ class NDCG(_RankingMetric):
     """
 
     log_base: int = attr.ib(default=2)
+    divide_by_achievable: bool = attr.ib(default=False)
 
     def calc_per_user(self, reco: pd.DataFrame, interactions: pd.DataFrame) -> pd.Series:
         """
@@ -429,15 +440,36 @@ def calc_per_user_from_merged(self, merged: pd.DataFrame, is_debiased: bool = Fa
         if not is_debiased and self.debias_config is not None:
             merged = debias_interactions(merged, self.debias_config)
 
-        dcg = (merged[Columns.Rank] <= self.k).astype(int) / log_at_base(merged[Columns.Rank] + 1, self.log_base)
-        idcg = (1 / log_at_base(np.arange(1, self.k + 1) + 1, self.log_base)).sum()
-        ndcg = (
-            pd.DataFrame({Columns.User: merged[Columns.User], "__ndcg": dcg / idcg})
-            .groupby(Columns.User, sort=False)["__ndcg"]
-            .sum()
-            .rename(None)
-        )
-        return ndcg
+        # DCG
+        # Avoid division by 0 with `+1` for rank value in denominator before taking logarithm
+        merged["__DCG"] = (merged[Columns.Rank] <= self.k).astype(int) / log_at_base(
+            merged[Columns.Rank] + 1, self.log_base
+        )
+        ranks = np.arange(1, self.k + 1)
+        discounted_gains = 1 / log_at_base(ranks + 1, self.log_base)
+
+        if self.divide_by_achievable:
+            grouped = merged.groupby(Columns.User, sort=False)
+            stats = grouped.agg(n_items=(Columns.Item, "count"), dcg=("__DCG", "sum"))
+
+            # IDCG
+            n_items_to_ndcg_map = dict(zip(ranks, discounted_gains.cumsum()))
+            n_items_to_ndcg_map[0] = 0
+            idcg = stats["n_items"].clip(upper=self.k).map(n_items_to_ndcg_map)
+
+            # NDCG
+            ndcg = stats["dcg"] / idcg
+
+        else:
+            idcg = discounted_gains.sum()
+            ndcg = (
+                pd.DataFrame({Columns.User: merged[Columns.User], "__ndcg": merged["__DCG"] / idcg})
+                .groupby(Columns.User, sort=False)["__ndcg"]
+                .sum()
+            )
+
+        del merged["__DCG"]
+        return ndcg.rename(None)
 
 
 class MRR(_RankingMetric):
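
A worked example of what the new flag changes (not part of the commit; it assumes the standard `calc(reco, interactions)` entry point of RecTools ranking metrics, and the data is illustrative):

```python
import numpy as np
import pandas as pd

from rectools import Columns
from rectools.metrics import NDCG

# One user with a single relevant test item, recommended at rank 3.
reco = pd.DataFrame({Columns.User: [1, 1, 1], Columns.Item: [7, 8, 9], Columns.Rank: [1, 2, 3]})
interactions = pd.DataFrame({Columns.User: [1], Columns.Item: [9]})

dcg = 1 / np.log2(3 + 1)  # single hit at rank 3

# Default: IDCG assumes all k positions could hold relevant items.
idcg_full = 1 / np.log2(2) + 1 / np.log2(3) + 1 / np.log2(4)
assert np.isclose(NDCG(k=3).calc(reco, interactions), dcg / idcg_full)  # ~0.235

# divide_by_achievable=True: the ideal ranking puts the one relevant item at rank 1.
idcg_achievable = 1 / np.log2(2)  # == 1
assert np.isclose(
    NDCG(k=3, divide_by_achievable=True).calc(reco, interactions), dcg / idcg_achievable
)  # 0.5
```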

rectools/metrics/scoring.py

Lines changed: 10 additions & 1 deletion
@@ -1,4 +1,4 @@
-# Copyright 2022-2024 MTS (Mobile Telesystems)
+# Copyright 2022-2025 MTS (Mobile Telesystems)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -23,6 +23,7 @@
 
 from .auc import AucMetric, calc_auc_metrics
 from .base import Catalog, MetricAtK, merge_reco
+from .catalog import CatalogMetric, calc_catalog_metrics
 from .classification import ClassificationMetric, SimpleClassificationMetric, calc_classification_metrics
 from .diversity import DiversityMetric, calc_diversity_metrics
 from .dq import CrossDQMetric, RecoDQMetric, calc_cross_dq_metrics, calc_reco_dq_metrics
@@ -150,6 +151,14 @@ def calc_metrics(  # noqa  # pylint: disable=too-many-branches,too-many-locals,t
     novelty_values = calc_novelty_metrics(novelty_metrics, reco, prev_interactions)
     results.update(novelty_values)
 
+    # Catalog
+    catalog_metrics = select_by_type(metrics, CatalogMetric)
+    if catalog_metrics:
+        if catalog is None:
+            raise ValueError("For calculating catalog metrics it's necessary to set 'catalog'")
+        catalog_values = calc_catalog_metrics(catalog_metrics, reco, catalog)
+        results.update(catalog_values)
+
     # Popularity
     popularity_metrics = select_by_type(metrics, PopularityMetric)
     if popularity_metrics:
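
A minimal sketch of reaching the new branch through `calc_metrics` (not part of the commit; it assumes the `metrics`, `reco` and `catalog` keyword arguments visible in the function body above, and the data is illustrative):

```python
import numpy as np
import pandas as pd

from rectools import Columns
from rectools.metrics import CatalogCoverage, calc_metrics

reco = pd.DataFrame(
    {
        Columns.User: [1, 1, 2],
        Columns.Item: [3, 4, 3],
        Columns.Rank: [1, 2, 1],
    }
)

# Omitting `catalog` here would raise the ValueError added above.
results = calc_metrics(
    metrics={"catalog_coverage": CatalogCoverage(k=2)},
    reco=reco,
    catalog=np.arange(10),
)
print(results)  # {'catalog_coverage': 0.2}
```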

tests/metrics/test_catalog.py

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+# Copyright 2025 MTS (Mobile Telesystems)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=attribute-defined-outside-init
+
+import numpy as np
+import pandas as pd
+
+from rectools import Columns
+from rectools.metrics import CatalogCoverage
+
+
+class TestCatalogCoverage:
+    def setup_method(self) -> None:
+        self.metric = CatalogCoverage(k=2)
+        self.reco = pd.DataFrame(
+            {
+                Columns.User: [1, 1, 1, 2, 2, 3, 4],
+                Columns.Item: [1, 2, 3, 1, 2, 1, 1],
+                Columns.Rank: [1, 2, 3, 1, 1, 3, 2],
+            }
+        )
+
+    def test_calc(self) -> None:
+        catalog = np.arange(5)
+        expected = 0.4
+        assert self.metric.calc(self.reco, catalog) == expected
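
For reference, the expected value follows directly from the definition: within the top k=2 the recommendations cover only the unique items {1, 2}, and the catalog holds 5 items, so coverage is 2 / 5 = 0.4.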

tests/metrics/test_ranking.py

Lines changed: 10 additions & 8 deletions
@@ -1,4 +1,4 @@
-# Copyright 2022-2024 MTS (Mobile Telesystems)
+# Copyright 2022-2025 MTS (Mobile Telesystems)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -99,13 +99,15 @@ class TestNDCG:
     _idcg_at_3 = 1 / np.log2(2) + 1 / np.log2(3) + 1 / np.log2(4)
 
     @pytest.mark.parametrize(
-        "k,expected_ndcg",
+        "k,divide_by_achievable,expected_ndcg",
         (
-            (1, [0, 0, 1, 1, 0]),
-            (3, [0, 0, 1, 1 / _idcg_at_3, 0.5 / _idcg_at_3]),
+            (1, False, [0, 0, 1, 1, 0]),
+            (3, False, [0, 0, 1, 1 / _idcg_at_3, 0.5 / _idcg_at_3]),
+            (1, True, [0, 0, 1, 1, 0]),
+            (3, True, [0, 0, 1, 1, (1 / np.log2(4)) / (1 / np.log2(2))]),
         ),
     )
-    def test_calc(self, k: int, expected_ndcg: tp.List[float]) -> None:
+    def test_calc(self, k: int, divide_by_achievable: bool, expected_ndcg: tp.List[float]) -> None:
         reco = pd.DataFrame(
             {
                 Columns.User: [1, 2, 3, 3, 3, 4, 5, 5, 5, 5, 6],
@@ -115,12 +117,12 @@ def test_calc(self, k: int, expected_ndcg: tp.List[float]) -> None:
         )
         interactions = pd.DataFrame(
             {
-                Columns.User: [1, 2, 3, 3, 3, 4, 5, 5, 5, 5],
-                Columns.Item: [1, 1, 1, 2, 3, 1, 1, 2, 3, 4],
+                Columns.User: [1, 2, 3, 3, 3, 4, 5],
+                Columns.Item: [1, 1, 1, 2, 3, 1, 1],
             }
         )
 
-        metric = NDCG(k=k)
+        metric = NDCG(k=k, divide_by_achievable=divide_by_achievable)
         expected_metric_per_user = pd.Series(
             expected_ndcg,
             index=pd.Series([1, 2, 3, 4, 5], name=Columns.User),
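
The two new `True` cases differ from their `False` counterparts only for users 4 and 5, each of whom now has a single relevant test item: achievable IDCG is then capped at 1 / log2(2) = 1, so user 4's expected value rises from 1 / _idcg_at_3 to 1, and user 5's from 0.5 / _idcg_at_3 to (1 / np.log2(4)) / (1 / np.log2(2)) = 0.5.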

tests/metrics/test_scoring.py

Lines changed: 5 additions & 1 deletion
@@ -1,4 +1,4 @@
-# Copyright 2022-2024 MTS (Mobile Telesystems)
+# Copyright 2022-2025 MTS (Mobile Telesystems)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -26,6 +26,7 @@
     PAP,
     Accuracy,
     AvgRecPopularity,
+    CatalogCoverage,
     CoveredUsers,
     DebiasConfig,
     F1Beta,
@@ -118,6 +119,7 @@ def test_success(self) -> None:
             "sufficient": SufficientReco(k=2),
             "unrepeated": UnrepeatedReco(k=2),
             "covered_users": CoveredUsers(k=2),
+            "catalog_coverage": CatalogCoverage(k=2),
         }
         with pytest.warns(UserWarning, match="Custom metrics are not supported"):
             actual = calc_metrics(
@@ -147,6 +149,7 @@ def test_success(self) -> None:
             "sufficient": 0.25,
             "unrepeated": 1,
             "covered_users": 0.75,
+            "catalog_coverage": 0.2,
         }
         assert actual == expected
 
@@ -164,6 +167,7 @@ def test_success(self) -> None:
             (PartialAUC(k=1), ["reco"]),
             (Intersection(k=1), ["reco"]),
             (CoveredUsers(k=1), ["reco"]),
+            (CatalogCoverage(k=1), ["reco"]),
         ),
     )
     def test_raises(self, metric: MetricAtK, arg_names: tp.List[str]) -> None:
