Skip to content

Commit e124154

Browse files
committed
fix bug (ndcg and recall nan)
1 parent 9fe043d commit e124154

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

dhg/metrics/recommender.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,9 @@ def recall(
112112
assert y_true.max() == 1, "The input y_true must be binary."
113113
pred_seq = y_true.gather(1, torch.argsort(y_pred, dim=-1, descending=True))[:, :k]
114114
num_true = y_true.sum(dim=1)
115-
res_list = (pred_seq.sum(dim=1) / num_true).detach().cpu()
115+
res_list = (pred_seq.sum(dim=1) / num_true).cpu()
116+
res_list[torch.isinf(res_list)] = 0
117+
res_list[torch.isnan(res_list)] = 0
116118
if ret_batch:
117119
return [res.item() for res in res_list]
118120
else:
@@ -169,6 +171,7 @@ def ndcg(
169171

170172
res_list = (pred_dcg / ideal_dcg).detach().cpu()
171173
res_list[torch.isinf(res_list)] = 0
174+
res_list[torch.isnan(res_list)] = 0
172175
if ret_batch:
173176
return [res.item() for res in res_list]
174177
else:

dhg/metrics/retrieval.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,9 @@ def recall(
104104
assert y_true.max() == 1, "The input y_true must be binary."
105105
pred_seq = y_true.gather(1, torch.argsort(y_pred, dim=-1, descending=True))[:, :k]
106106
num_true = y_true.sum(dim=1)
107-
res_list = (pred_seq.sum(dim=1) / num_true).detach().cpu()
107+
res_list = (pred_seq.sum(dim=1) / num_true).cpu()
108+
res_list[torch.isinf(res_list)] = 0
109+
res_list[torch.isnan(res_list)] = 0
108110
if ret_batch:
109111
return res_list
110112
else:
@@ -252,6 +254,7 @@ def ndcg(
252254

253255
res_list = pred_dcg / ideal_dcg
254256
res_list[torch.isinf(res_list)] = 0
257+
res_list[torch.isnan(res_list)] = 0
255258
if ret_batch:
256259
return res_list
257260
else:

0 commit comments

Comments
 (0)