Skip to content

Commit 9b6dd3e

Browse files
authored
Merge pull request #16 from jiishy/0.9.2
jsy add ndcg test
2 parents eca4185 + ac2db34 commit 9b6dd3e

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

tests/metrics/test_recommender.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pytest
22

33
import torch
4+
import math
45
from sklearn.metrics import ndcg_score
56
import dhg.metrics.recommender as dm
67

@@ -30,3 +31,13 @@ def test_ndcg():
3031
assert dm.ndcg(y_true, y_score, k=3) == pytest.approx(ndcg_score(y_true, y_score, k=3))
3132
assert dm.ndcg(y_true, y_score, k=4) == pytest.approx(ndcg_score(y_true, y_score, k=4))
3233
assert dm.ndcg(y_true, y_score, k=5) == pytest.approx(ndcg_score(y_true, y_score, k=5))
34+
35+
y_true = torch.tensor([0, 1, 0, 0, 1, 1])
36+
y_pred = torch.tensor([0.8, 0.9, 0.6, 0.7, 0.4, 0.5])
37+
assert dm.ndcg(y_true, y_pred, k=2) == pytest.approx((1 / math.log2(2)) / (1 / math.log2(2) + 1 / math.log2(3)))
38+
assert dm.ndcg(y_true, y_pred, k=3) == pytest.approx((1 / math.log2(2)) / (1 / math.log2(2) + 1 / math.log2(3) + 1 / math.log2(4)))
39+
assert dm.ndcg(y_true, y_pred, k=5) == pytest.approx((1 / math.log2(2) + 1 / math.log2(6)) / (1 / math.log2(2) + 1 / math.log2(3) + 1 / math.log2(4)))
40+
41+
y_true = torch.tensor([3, 2, 3, 0, 1, 2, 3, 2])
42+
y_pred = torch.tensor([0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2])
43+
assert dm.ndcg(y_true, y_pred, k=6) == pytest.approx(0.785, abs=1e-4)

0 commit comments

Comments
 (0)