Skip to content

Commit 28a1b7b

Browse files
committed
fix retrieval metric class bugs
1 parent 00524b0 commit 28a1b7b

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

dhg/metrics/retrieval.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def precision(
6464
y_true, y_pred, k = _format_inputs(y_true, y_pred, k)
6565
assert y_true.max() == 1, "The input y_true must be binary."
6666
pred_seq = y_true.gather(1, torch.argsort(y_pred, dim=-1, descending=True))[:, :k]
67-
res_list = pred_seq.sum(dim=1) / k
67+
res_list = (pred_seq.sum(dim=1) / k).detach().cpu()
6868
if ret_batch:
6969
return res_list
7070
else:
@@ -99,7 +99,7 @@ def recall(
9999
assert y_true.max() == 1, "The input y_true must be binary."
100100
pred_seq = y_true.gather(1, torch.argsort(y_pred, dim=-1, descending=True))[:, :k]
101101
num_true = y_true.sum(dim=1)
102-
res_list = pred_seq.sum(dim=1) / num_true
102+
res_list = (pred_seq.sum(dim=1) / num_true).detach().cpu()
103103
if ret_batch:
104104
return res_list
105105
else:
@@ -140,7 +140,7 @@ def ap(
140140
if method == "pascal_voc":
141141
res = torch.flip(res, dims=(0,))
142142
res = torch.cummax(res, dim=0)[0]
143-
return res.mean().item()
143+
return res.detach().cpu().mean().item()
144144

145145

146146
def map(
@@ -193,7 +193,7 @@ def _dcg(matrix: torch.Tensor) -> torch.Tensor:
193193
assert matrix.dim() == 2, "The input must be a 2-D tensor."
194194
n, k = matrix.shape
195195
denom = torch.log2(torch.arange(k, device=matrix.device) + 2.0).view(1, -1).repeat(n, 1)
196-
return (matrix / denom).sum(dim=-1)
196+
return (matrix / denom).detach().cpu().sum(dim=-1)
197197

198198

199199
def ndcg(
@@ -225,7 +225,7 @@ def ndcg(
225225
pred_dcg = _dcg(pred_seq)
226226
ideal_dcg = _dcg(ideal_seq)
227227

228-
res_list = pred_dcg / ideal_dcg
228+
res_list = (pred_dcg / ideal_dcg).detach().cpu()
229229
res_list[torch.isinf(res_list)] = 0
230230
if ret_batch:
231231
return res_list
@@ -260,7 +260,7 @@ def rr(y_true: torch.Tensor, y_pred: torch.Tensor, k: Optional[int] = None) -> f
260260

261261
pred_seq = y_true[torch.argsort(y_pred, dim=-1, descending=True)][:k]
262262
pred_index = torch.nonzero(pred_seq).view(-1)
263-
res = (1 / (pred_index + 1)).mean()
263+
res = (1 / (pred_index + 1)).mean().detach().cpu()
264264
return res.mean().item()
265265

266266

0 commit comments

Comments
 (0)