Commit 5fa1259
Fixes to the RecallAt metric for a cut-off of 1 and for a scalar cut-off (#720)

* Fixed issues in RecallAt when a cut-off of 1 was used and when a scalar cut-off was provided
* Fixed an exception when using a RankingMetric with a single cut-off (torchmetrics was converting the result to a scalar)
1 parent f4946bf commit 5fa1259

File tree

4 files changed: +89 -12 lines changed
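The commit message above describes two RecallAt behaviours that now work: passing a scalar cut-off instead of a list, and using a cut-off of 1. As a rough usage sketch (not part of the diff; inputs are borrowed from the new unit tests below, and the printed values are illustrative only), both calls now return one recall value per cut-off instead of failing or collapsing to a scalar:

import torch
from transformers4rec.torch.ranking_metric import RecallAt

scores = torch.tensor([[1, 2, 3, 4, 5, 4, 3, 2, 1]], dtype=torch.float32)
labels = torch.tensor([[0, 0, 0, 0, 1, 0, 0, 0, 0]], dtype=torch.float32)

# A scalar cut-off is now normalized to a one-element list internally
metric = RecallAt(4, labels_onehot=False)
print(metric(scores, labels))   # 1-element tensor: one value for k=4

# A cut-off of 1 no longer breaks the recall computation
metric = RecallAt([1, 2], labels_onehot=False)
print(metric(scores, labels))   # 2-element tensor: values for k=1 and k=2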

tests/unit/torch/test_ranking_metrics.py

Lines changed: 53 additions & 1 deletion

@@ -18,7 +18,7 @@
 import torch
 
 import transformers4rec.torch as tr
-from transformers4rec.torch.ranking_metric import MeanReciprocalRankAt
+from transformers4rec.torch.ranking_metric import MeanReciprocalRankAt, RecallAt
 
 # fixed parameters for tests
 list_metrics = list(tr.ranking_metric.ranking_metrics_registry.keys())
@@ -64,6 +64,58 @@ def test_mean_recipricol_rank():
     )
 
 
+def test_recall_at():
+    metric = RecallAt([1, 2, 3, 4], labels_onehot=False)
+    result = metric(
+        torch.tensor(
+            [[1, 2, 3, 4, 5, 4, 3, 2, 1], [1, 2, 3, 4, 5, 4, 3, 2, 1], [1, 2, 3, 4, 5, 4, 3, 2, 1]]
+        ),
+        torch.tensor(
+            [[0, 0, 0, 0, 0, 0, 0, 1, 0], [0, 0, 0, 0, 0, 1, 0, 0, 0], [0, 0, 0, 0, 1, 0, 0, 0, 0]]
+        ),
+    )
+    assert torch.all(
+        torch.lt(
+            torch.abs(torch.add(result, -torch.tensor([0.3333, 0.3333, 0.6667, 0.6667]))), 1e-3
+        )
+    )
+
+
+def test_recall_at_3d():
+    metric = RecallAt([1, 2, 3, 4], labels_onehot=False)
+    result = metric(
+        torch.tensor(
+            [
+                [[1, 2, 3, 4, 5, 4, 3, 2, 1], [1, 2, 3, 4, 5, 4, 3, 2, 1]],
+                [[1, 2, 3, 4, 5, 4, 3, 2, 1], [1, 2, 3, 4, 5, 4, 3, 2, 1]],
+            ]
+        ),
+        torch.tensor(
+            [
+                [[0, 0, 0, 0, 0, 0, 0, 1, 0], [0, 0, 0, 0, 0, 1, 0, 0, 0]],
+                [[0, 0, 0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 0, 0, 1, 0, 0, 0]],
+            ]
+        ),
+    )
+    assert torch.all(
+        torch.lt(torch.abs(torch.add(result, -torch.tensor([0.25, 0.25, 0.75, 0.75]))), 1e-3)
+    )
+
+
+@pytest.mark.parametrize("cutoff", [4, [4]])
+def test_recall_at_single_metric(cutoff):
+    metric = RecallAt(cutoff, labels_onehot=False)
+    result = metric(
+        torch.tensor(
+            [[1, 2, 3, 4, 5, 4, 3, 2, 1], [1, 2, 3, 4, 5, 4, 3, 2, 1], [1, 2, 3, 4, 5, 4, 3, 2, 1]]
+        ),
+        torch.tensor(
+            [[0, 0, 0, 0, 0, 0, 0, 1, 0], [0, 0, 0, 0, 0, 1, 0, 0, 0], [0, 0, 0, 0, 1, 0, 0, 0, 0]]
+        ),
+    )
+    assert torch.all(torch.lt(torch.abs(torch.add(result, -torch.tensor([0.6667]))), 1e-3))
+
+
 # TODO: Compare the metrics @K between pytorch and numpy
 @pytest.mark.parametrize("metric", list_metrics)
 def test_numpy_comparison(metric):
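The expected values in test_recall_at follow from each row having exactly one relevant item, so every row contributes 0 or 1 and the batch recall at each cut-off is a multiple of 1/3. A minimal stand-alone sketch of that computation (my own illustration, not code from the commit; it assumes torch.topk breaks score ties by picking the lower index, as the expected values suggest):

import torch

scores = torch.tensor([[1, 2, 3, 4, 5, 4, 3, 2, 1]] * 3, dtype=torch.float32)
labels = torch.tensor(
    [[0, 0, 0, 0, 0, 0, 0, 1, 0], [0, 0, 0, 0, 0, 1, 0, 0, 0], [0, 0, 0, 0, 1, 0, 0, 0, 0]],
    dtype=torch.float32,
)

for k in [1, 2, 3, 4]:
    _, topk_idx = torch.topk(scores, k, dim=-1)
    hits = labels.gather(-1, topk_idx).sum(-1)   # 0 or 1 hit per row
    recall = (hits / labels.sum(-1)).mean()      # average over the 3 rows
    # With lower-index tie-breaking this prints roughly 0.3333, 0.3333, 0.6667, 0.6667
    print(k, round(recall.item(), 4))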

tests/unit/torch/test_trainer.py

Lines changed: 17 additions & 2 deletions

@@ -24,6 +24,7 @@
 import transformers4rec.torch as tr
 from transformers4rec.config import trainer
 from transformers4rec.config import transformer as tconf
+from transformers4rec.torch.ranking_metric import NDCGAt, RecallAt
 
 
 @pytest.mark.parametrize("batch_size", [16, 32])
@@ -327,6 +328,20 @@ def test_evaluate_results(torch_yoochoose_next_item_prediction_model):
                 "eval_/next-item/recall_at_20",
             ],
         ),
+        (
+            tr.NextItemPredictionTask(
+                weight_tying=False,
+                metrics=(
+                    NDCGAt(top_ks=[5, 10], labels_onehot=True),
+                    RecallAt(top_ks=[10], labels_onehot=True),
+                ),
+            ),
+            [
+                "eval_/next-item/ndcg_at_5",
+                "eval_/next-item/ndcg_at_10",
+                "eval_/next-item/recall_at_10",
+            ],
+        ),
         (
             tr.BinaryClassificationTask("click", summary_type="mean"),
             [
@@ -347,7 +362,7 @@ def test_trainer_music_streaming(task_and_metrics):
     data = tr.data.music_streaming_testing_data
     schema = data.schema
     batch_size = 16
-    task, default_metric = task_and_metrics
+    task, expected_metrics = task_and_metrics
 
     inputs = tr.TabularSequenceFeatures.from_schema(
         schema,
@@ -388,7 +403,7 @@
     predictions = recsys_trainer.predict(data.path)
 
     assert isinstance(eval_metrics, dict)
-    assert set(default_metric).issubset(set(eval_metrics.keys()))
+    assert set(expected_metrics).issubset(set(eval_metrics.keys()))
     assert eval_metrics["eval_/loss"] is not None
 
     assert predictions is not None

transformers4rec/torch/model/prediction_task.py

Lines changed: 6 additions & 1 deletion

@@ -486,7 +486,8 @@ def calculate_metrics(self, predictions, targets) -> Dict[str, torch.Tensor]: #
         predictions = self.forward_to_prediction_fn(predictions)
 
         for metric in self.metrics:
-            outputs[self.metric_name(metric)] = metric(predictions, targets)
+            result = metric(predictions, targets)
+            outputs[self.metric_name(metric)] = result
 
         return outputs
 
@@ -502,6 +503,10 @@ def compute_metrics(self):
         topks = {self.metric_name(metric): metric.top_ks for metric in self.metrics}
         results = {}
         for name, metric in metrics.items():
+            # Fix for when using a single cut-off, as torch metrics convert results to scalar
+            # when a single element vector is returned
+            if len(metric.size()) == 0:
+                metric = metric.unsqueeze(0)
             for measure, k in zip(metric, topks[name]):
                 results[f"{name}_{k}"] = measure
         return results
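The compute_metrics change above guards against a one-element result coming back as a 0-dim tensor, which cannot be iterated by the zip that follows it. A small stand-alone illustration of that behaviour (my own sketch, not code from the commit):

import torch

metric = torch.tensor(0.6667)        # what a single cut-off can reduce to: a 0-dim scalar
topks = [4]

# Without the guard, zip(metric, topks) would fail: 0-dim tensors cannot be iterated
if len(metric.size()) == 0:
    metric = metric.unsqueeze(0)     # restore a 1-D, one-element tensor

for measure, k in zip(metric, topks):
    print(f"recall_at_{k}", measure.item())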

transformers4rec/torch/ranking_metric.py

Lines changed: 13 additions & 8 deletions

@@ -42,6 +42,9 @@ class RankingMetric(tm.Metric):
     def __init__(self, top_ks=None, labels_onehot=False):
         super(RankingMetric, self).__init__()
         self.top_ks = top_ks or [2, 5]
+        if not isinstance(self.top_ks, (list, tuple)):
+            self.top_ks = [self.top_ks]
+
         self.labels_onehot = labels_onehot
         # Store the mean of the batch metrics (for each cut-off at topk)
         self.add_state("metric_mean", default=[], dist_reduce_fx="cat")
@@ -50,7 +53,9 @@ def update(self, preds: torch.Tensor, target: torch.Tensor, **kwargs): # type:
         # Computing the metrics at different cut-offs
         if self.labels_onehot:
             target = torch_utils.tranform_label_to_onehot(target, preds.size(-1))
-        metric = self._metric(self.top_ks, preds.view(-1, preds.size(-1)), target)
+        metric = self._metric(
+            self.top_ks, preds.view(-1, preds.size(-1)), target.view(-1, target.size(-1))
+        )
         self.metric_mean.append(metric)  # type: ignore
 
     def compute(self):
@@ -126,17 +131,17 @@ def _metric(self, ks: torch.Tensor, scores: torch.Tensor, labels: torch.Tensor)
 
         # Compute recalls at K
         num_relevant = torch.sum(labels, dim=-1)
-        rel_indices = (num_relevant != 0).nonzero()
-        rel_count = num_relevant[rel_indices].squeeze()
+        rel_indices = (num_relevant != 0).nonzero().squeeze()
+        rel_count = num_relevant[rel_indices]
 
         if rel_indices.shape[0] > 0:
             for index, k in enumerate(ks):
-                rel_labels = topk_labels[rel_indices, : int(k)].squeeze()
+                rel_labels = topk_labels[rel_indices, : int(k)]
 
-                recalls[rel_indices, index] = (
-                    torch.div(torch.sum(rel_labels, dim=-1), rel_count)
-                    .reshape(len(rel_indices), 1)
-                    .to(dtype=torch.float32)
+                recalls[rel_indices, index] = torch.div(
+                    torch.sum(rel_labels, dim=-1), rel_count
+                ).to(
+                    dtype=torch.float32
                 )  # Ensuring type is double, because it can be float if --fp16
 
         return recalls
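The _metric change above drops two .squeeze() calls. My reading of why that matters for a cut-off of 1 (an illustration, not code from the commit): slicing the top-k labels down to a single column and then squeezing removes the k dimension, so the subsequent sum over dim=-1 reduces across rows instead of across the cut-off.

import torch

# Hypothetical top-k label matrix for 3 rows and k up to 2
topk_labels = torch.tensor([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]])
rel_indices = torch.tensor([0, 1, 2])
k = 1

old = topk_labels[rel_indices, :k].squeeze()   # shape (3,): the k dimension is gone
new = topk_labels[rel_indices, :k]             # shape (3, 1): the k dimension is kept

print(torch.sum(old, dim=-1).shape)  # torch.Size([]): a single scalar over all rows
print(torch.sum(new, dim=-1).shape)  # torch.Size([3]): one hit count per row, as intended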
