def apply_top_k_top_p_npu(
    logits: torch.Tensor,
    k: Optional[torch.Tensor],
    p: Optional[torch.Tensor],
) -> torch.Tensor:
    """Apply top-k and/or top-p filtering to logits, optimized for NPU.

    Avoids torch.scatter (slow on NPU) by operating on the ascending-sorted
    logits and writing the surviving entries back through flat indexing.

    Args:
        logits: ``(batch_size, vocab_size)`` raw logits.
        k: optional per-row top-k counts, shape ``(batch_size,)``; values
            are clamped to ``[1, vocab_size]``.
        p: optional per-row top-p (nucleus) thresholds, shape ``(batch_size,)``.

    Returns:
        Logits of the same shape with every filtered-out entry set to ``-inf``.
    """
    if k is None and p is None:
        return logits

    batch_size, vocab_size = logits.shape
    device = logits.device
    # Ascending sort: the kept (largest) tokens sit at the END of each row.
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

    if k is not None:
        safe_k = torch.clamp(k, min=1, max=vocab_size)
        # In ascending order, the k-th largest value lives at index
        # vocab_size - k; everything strictly below it is dropped.
        boundary = logits_sort.gather(1, (vocab_size - safe_k).unsqueeze(1))
        top_k_mask = logits_sort < boundary
        logits_sort = logits_sort.masked_fill(top_k_mask, -float("inf"))
    else:
        top_k_mask = torch.zeros_like(logits_sort, dtype=torch.bool)

    # Per-row count of entries removed by top-k.
    cutoffs = top_k_mask.sum(dim=-1)
    # Row offsets for indexing into the flattened (batch * vocab) logits.
    strides = torch.arange(
        0, batch_size * vocab_size, vocab_size, device=device
    ).unsqueeze(1)

    if p is not None:
        # Columns before the smallest per-row cutoff can never survive, so
        # trim them to keep the softmax/cumsum over a narrower slice.
        global_cutoff = cutoffs.min()
        active_idx = logits_idx[:, global_cutoff:]
        probs_sort = logits_sort[:, global_cutoff:].softmax(dim=-1)
        cumprob = probs_sort.cumsum(dim=-1)
        width = probs_sort.size(1)
        # Keep the nucleus: in ascending order these are the tokens whose
        # cumulative probability EXCEEDS 1 - p.  (The previous `<=` kept the
        # low-probability tail instead — inverted mask.)  The last column
        # (most likely token) is always kept so at least one token survives.
        keep_mask = (cumprob > (1 - p.unsqueeze(1))) | (
            torch.arange(width, device=device) == width - 1
        )
    else:
        active_idx = logits_idx
        # Top-k only: keep every sorted position at or past the row's cutoff.
        keep_mask = (
            torch.arange(vocab_size, device=device).expand(batch_size, -1)
            >= cutoffs.unsqueeze(1)
        )

    # Map surviving sorted positions back to flat positions in the original
    # layout, then rebuild: survivors keep their logit, the rest get -inf.
    valid_idx = (active_idx + strides).masked_select(keep_mask)
    logits_flatten = logits.flatten()
    output = torch.full_like(logits_flatten, -float("inf"))
    output[valid_idx] = logits_flatten[valid_idx]
    return output.reshape(batch_size, vocab_size)
0 commit comments