@@ -64,7 +64,7 @@ def precision(
6464 y_true , y_pred , k = _format_inputs (y_true , y_pred , k )
6565 assert y_true .max () == 1 , "The input y_true must be binary."
6666 pred_seq = y_true .gather (1 , torch .argsort (y_pred , dim = - 1 , descending = True ))[:, :k ]
67- res_list = pred_seq .sum (dim = 1 ) / k
67+ res_list = ( pred_seq .sum (dim = 1 ) / k ). detach (). cpu ()
6868 if ret_batch :
6969 return res_list
7070 else :
@@ -99,7 +99,7 @@ def recall(
9999 assert y_true .max () == 1 , "The input y_true must be binary."
100100 pred_seq = y_true .gather (1 , torch .argsort (y_pred , dim = - 1 , descending = True ))[:, :k ]
101101 num_true = y_true .sum (dim = 1 )
102- res_list = pred_seq .sum (dim = 1 ) / num_true
102+ res_list = ( pred_seq .sum (dim = 1 ) / num_true ). detach (). cpu ()
103103 if ret_batch :
104104 return res_list
105105 else :
@@ -140,7 +140,7 @@ def ap(
140140 if method == "pascal_voc" :
141141 res = torch .flip (res , dims = (0 ,))
142142 res = torch .cummax (res , dim = 0 )[0 ]
143- return res .mean ().item ()
143+ return res .detach (). cpu (). mean ().item ()
144144
145145
146146def map (
@@ -193,7 +193,7 @@ def _dcg(matrix: torch.Tensor) -> torch.Tensor:
193193 assert matrix .dim () == 2 , "The input must be a 2-D tensor."
194194 n , k = matrix .shape
195195 denom = torch .log2 (torch .arange (k , device = matrix .device ) + 2.0 ).view (1 , - 1 ).repeat (n , 1 )
196- return (matrix / denom ).sum (dim = - 1 )
196+ return (matrix / denom ).detach (). cpu (). sum (dim = - 1 )
197197
198198
199199def ndcg (
@@ -225,7 +225,7 @@ def ndcg(
225225 pred_dcg = _dcg (pred_seq )
226226 ideal_dcg = _dcg (ideal_seq )
227227
228- res_list = pred_dcg / ideal_dcg
228+ res_list = ( pred_dcg / ideal_dcg ). detach (). cpu ()
229229 res_list [torch .isinf (res_list )] = 0
230230 if ret_batch :
231231 return res_list
@@ -260,7 +260,7 @@ def rr(y_true: torch.Tensor, y_pred: torch.Tensor, k: Optional[int] = None) -> f
260260
261261 pred_seq = y_true [torch .argsort (y_pred , dim = - 1 , descending = True )][:k ]
262262 pred_index = torch .nonzero (pred_seq ).view (- 1 )
263- res = (1 / (pred_index + 1 )).mean ()
263+ res = (1 / (pred_index + 1 )).mean (). detach (). cpu ()
264264 return res .mean ().item ()
265265
266266
0 commit comments