Skip to content

Commit 9201d05

Browse files
committed
refactor: support scenarios where top_p or top_k is None
Signed-off-by: linfeng-yuan <1102311262@qq.com>
1 parent c3b40a6 commit 9201d05

File tree

1 file changed

+26
-20
lines changed

1 file changed

+26
-20
lines changed

vllm_ascend/sample/ops/ascend_topk_topp_sampler.py

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -35,30 +35,36 @@ def apply_top_k_top_p_npu(
3535
k: Optional[torch.Tensor],
3636
p: Optional[torch.Tensor],
3737
) -> torch.Tensor:
38-
"""Apply top-k and top-p optimized for NPU.
39-
40-
This algorithm avoids using torch.scatter which is time-consuming on NPU.
41-
"""
42-
# TODO(linfeng): consider the case that either p or k is applied
38+
"""Apply top-k and/or top-p optimized for NPU."""
4339
if k is None and p is None:
4440
return logits
41+
4542
batch_size, vocab_size = logits.shape
43+
device = logits.device
4644
logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
45+
if k is not None:
46+
safe_k = torch.clamp(k, min=1, max=vocab_size)
47+
boundary_idx = (vocab_size - safe_k).unsqueeze(1)
48+
boundary = logits_sort.gather(1, boundary_idx)
49+
top_k_mask = logits_sort < boundary
50+
logits_sort = logits_sort.masked_fill(top_k_mask, -float("inf"))
51+
else:
52+
top_k_mask = torch.zeros_like(logits_sort, dtype=torch.bool)
4753

48-
boundary = logits_sort.gather(1, (vocab_size - k).unsqueeze(dim=1))
49-
top_k_mask = logits_sort < boundary
50-
logits_sort.masked_fill_(top_k_mask, -float("inf"))
51-
cutoff = top_k_mask.sum(dim=-1).min()
52-
probs_sort = logits_sort.softmax(dim=-1)[:, cutoff:]
53-
probs_sum = probs_sort.cumsum(dim=-1)
54-
top_p_mask = probs_sum > 1 - p.unsqueeze(dim=1)
55-
top_p_mask[:, -1] = True
56-
strides = torch.arange(0, batch_size*vocab_size, vocab_size, device=logits.device)
57-
flatten_idx = logits_idx[:, cutoff:] + strides.unsqueeze(dim=1)
58-
valid_idx = torch.masked_select(flatten_idx, top_p_mask)
54+
cutoffs = top_k_mask.sum(dim=-1)
55+
strides = torch.arange(0, batch_size*vocab_size, vocab_size, device=device).unsqueeze(1)
56+
if p is not None:
57+
global_cutoff = cutoffs.min()
58+
active_part = logits_idx[:, global_cutoff:]
59+
probs_sort = logits_sort[:, global_cutoff:].softmax(dim=-1)
60+
cumprob = probs_sort.cumsum(dim=-1)
61+
top_p_mask = (cumprob <= (1 - p.unsqueeze(1))) | (torch.arange(probs_sort.size(1), device=device) == probs_sort.size(1)-1)
62+
else:
63+
active_part = logits_idx
64+
top_p_mask = torch.arange(vocab_size, device=device).expand(batch_size, -1) >= cutoffs.unsqueeze(1)
5965

66+
valid_idx = (active_part + strides).masked_select(top_p_mask)
6067
logits_flatten = logits.flatten()
61-
valid_logits = torch.index_select(logits_flatten, 0, valid_idx)
62-
logits = torch.empty_like(logits_flatten).fill_(-float("inf"))
63-
logits[valid_idx] = valid_logits
64-
return logits.reshape(batch_size, vocab_size)
68+
output = torch.full_like(logits_flatten, -float('inf'))
69+
output[valid_idx] = logits_flatten[valid_idx]
70+
return output.reshape(batch_size, vocab_size)

0 commit comments

Comments
 (0)