Skip to content

Commit 7f46d64

Browse files
committed
refactor: support scenarios where top_p or top_k is None
Signed-off-by: linfeng-yuan <1102311262@qq.com>
1 parent 3eb58d7 commit 7f46d64

File tree

6 files changed

+64
-97
lines changed

6 files changed

+64
-97
lines changed
Lines changed: 34 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,15 @@
11
from typing import Dict, Optional
22

33
import torch
4-
import torch.nn as nn
5-
6-
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample
74
from vllm.logger import init_logger
8-
5+
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample
96

107
logger = init_logger(__name__)
118

129

1310
class AscendTopKTopPSampler(TopKTopPSampler):
1411

15-
def __init__(self):
16-
super().__init__()
17-
# TODO(linfeng): eliminate warning for FlashInfer here
18-
self.forward = self.forward_npu
19-
20-
def forward_npu(
12+
def forward_native(
2113
self,
2214
logits: torch.Tensor,
2315
generators: Dict[int, torch.Generator],
@@ -28,37 +20,48 @@ def forward_npu(
2820
logits = apply_top_k_top_p_npu(logits, k, p)
2921
probs = logits.softmax(dim=-1, dtype=torch.float32)
3022
return random_sample(probs, generators)
31-
23+
3224

3325
def apply_top_k_top_p_npu(
    logits: torch.Tensor,
    k: Optional[torch.Tensor],
    p: Optional[torch.Tensor],
) -> torch.Tensor:
    """Apply top-k and/or top-p filtering to logits, optimized for NPU.

    This avoids ``torch.scatter`` (slow on NPU) by sorting the logits in
    ascending order once and scattering the surviving entries back through
    flat indices selected with a boolean mask.

    Args:
        logits: Raw logits of shape ``(batch_size, vocab_size)``.
        k: Per-row top-k counts of shape ``(batch_size,)``, or ``None`` to
            skip top-k filtering.
        p: Per-row top-p (nucleus) thresholds in ``[0, 1]`` of shape
            ``(batch_size,)``, or ``None`` to skip top-p filtering.

    Returns:
        Logits with every filtered-out entry replaced by ``-inf``; kept
        entries retain their original values. The input tensor itself is
        not modified.
    """
    if k is None and p is None:
        # Nothing to filter.
        return logits

    batch_size, vocab_size = logits.shape
    device = logits.device
    # Ascending sort: the kept (largest) tokens sit at the right-hand end.
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

    if k is not None:
        # Clamp k into [1, vocab_size] so the boundary gather is always valid.
        safe_k = torch.clamp(k, min=1, max=vocab_size)
        boundary_idx = (vocab_size - safe_k).unsqueeze(1)
        # Smallest logit that survives top-k in each row.
        boundary = logits_sort.gather(1, boundary_idx)
        top_k_mask = logits_sort < boundary  # True => dropped by top-k
        logits_sort = logits_sort.masked_fill(top_k_mask, -float("inf"))
    else:
        top_k_mask = torch.zeros_like(logits_sort, dtype=torch.bool)

    # Number of tokens removed by top-k per row.
    cutoffs = top_k_mask.sum(dim=-1)
    # Row offsets used to translate sorted column indices into flat indices.
    strides = torch.arange(0,
                           batch_size * vocab_size,
                           vocab_size,
                           device=device).unsqueeze(1)

    if p is not None:
        # Only the columns past the smallest cutoff can survive in any row;
        # restricting the softmax/cumsum to that slice saves work.
        global_cutoff = cutoffs.min()
        active_part = logits_idx[:, global_cutoff:]
        probs_sort = logits_sort[:, global_cutoff:].softmax(dim=-1)
        cumprob = probs_sort.cumsum(dim=-1)
        # Keep the nucleus: with an ascending sort, the top-p tokens are
        # exactly those whose cumulative probability exceeds 1 - p.
        # (The previous `<=` comparison kept the low-probability tail —
        # including top-k-masked -inf entries — and discarded the nucleus.)
        top_p_mask = cumprob > (1 - p.unsqueeze(1))
        # Always keep the highest-probability token, even when p == 0.
        top_p_mask[:, -1] = True
    else:
        # Top-k only: keep the rightmost (vocab_size - cutoff) sorted columns.
        active_part = logits_idx
        top_p_mask = torch.arange(vocab_size, device=device).expand(
            batch_size, -1) >= cutoffs.unsqueeze(1)

    # Flat indices (into logits.flatten()) of the surviving tokens.
    valid_idx = (active_part + strides).masked_select(top_p_mask)
    logits_flatten = logits.flatten()
    output = torch.full_like(logits_flatten, -float('inf'))
    output[valid_idx] = logits_flatten[valid_idx]
    return output.reshape(batch_size, vocab_size)

vllm_ascend/sample/ops/penalties.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
# SPDX-License-Identifier: Apache-2.0
22

33
import torch
4-
5-
from vllm.v1.sample.ops.penalties import _convert_to_tensors
64
from vllm.model_executor.layers.utils import get_token_bin_counts_and_mask
5+
from vllm.v1.sample.ops.penalties import _convert_to_tensors
76

87

98
def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
@@ -31,23 +30,25 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
3130
output_bin_counts, output_mask = get_token_bin_counts_and_mask(
3231
output_tokens_tensor, vocab_size, num_seqs)
3332

34-
3533
repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat(
3634
1, vocab_size)
37-
35+
3836
# Avoid IndexPut operations in original apply_penalties function which are extremely time-consuming on NPU.
3937
sequence_mask = prompt_mask | output_mask
40-
logits = torch.where(sequence_mask & torch.lt(logits, 0), logits * repetition_penalties,
41-
logits).to(logits.dtype)
42-
logits = torch.where(sequence_mask & torch.ge(logits, 0), logits / repetition_penalties,
43-
logits).to(logits.dtype)
38+
logits = torch.where(sequence_mask & torch.lt(logits, 0),
39+
logits * repetition_penalties,
40+
logits).to(logits.dtype)
41+
logits = torch.where(sequence_mask & torch.ge(logits, 0),
42+
logits / repetition_penalties,
43+
logits).to(logits.dtype)
4444

4545
# We follow the definition in OpenAI API.
4646
# Refer to https://platform.openai.com/docs/api-reference/parameter-details
4747
logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
4848
logits -= presence_penalties.unsqueeze(dim=1) * output_mask
4949
return logits
5050

51+
5152
def apply_all_penalties(
5253
logits: torch.Tensor,
5354
prompt_token_ids: torch.Tensor,
@@ -64,4 +65,4 @@ def apply_all_penalties(
6465
logits.device)
6566
return apply_penalties(logits, prompt_token_ids, output_tokens_t,
6667
presence_penalties, frequency_penalties,
67-
repetition_penalties)
68+
repetition_penalties)

vllm_ascend/sample/sampler.py

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,13 @@
33
from typing import Optional
44

55
import torch
6-
from vllm.model_executor.layers.sampler import (Sampler,
7-
SamplerOutput,
8-
_apply_min_tokens_penalty,
9-
_apply_min_p,
10-
_sample,
11-
SampleResultArgsType,
12-
get_logprobs,
13-
_build_sampler_output)
6+
from vllm.model_executor.layers.sampler import (Sampler, SampleResultArgsType,
7+
SamplerOutput, _apply_min_p,
8+
_apply_min_tokens_penalty,
9+
_build_sampler_output, _sample,
10+
get_logprobs)
1411
from vllm.model_executor.sampling_metadata import SamplingMetadata
12+
1513
from vllm_ascend.sample.ops.penalties import apply_penalties
1614

1715

@@ -61,7 +59,7 @@ def forward(
6159

6260
if do_top_p_top_k:
6361
logits = _apply_top_k_top_p_npu(logits, sampling_tensors.top_ps,
64-
sampling_tensors.top_ks)
62+
sampling_tensors.top_ks)
6563

6664
if do_min_p:
6765
logits = _apply_min_p(logits, sampling_tensors.min_ps)
@@ -83,21 +81,15 @@ def forward(
8381
)
8482

8583
if self.include_gpu_probs_tensor:
86-
# Since we will defer sampler result Pythonization,
87-
# preserve GPU-side tensors in support of later
88-
# deferred pythonization of logprobs
8984
assert maybe_sampled_tokens_tensor is not None
9085
on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
9186
else:
92-
# Since Pythonization has already happened, don't preserve
93-
# GPU-side tensors.
9487
on_device_tensors = None
9588

9689
# Get the logprobs query results.
9790
prompt_logprobs = None
9891
sample_logprobs = None
9992
if not sampling_metadata.skip_sampler_cpu_output:
100-
# Pythonize logprobs now (GPU -> CPU); do not defer.
10193
assert not isinstance(maybe_deferred_sample_results,
10294
SampleResultArgsType)
10395
prompt_logprobs, sample_logprobs = get_logprobs(
@@ -121,10 +113,9 @@ def _apply_top_k_top_p_npu(
121113
122114
This algorithm avoids using torch.scatter which is time-consuming on NPU.
123115
"""
124-
# TODO(linfeng): consider the case that either p or k is applied
125116
batch_size, vocab_size = logits.shape
126117
logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
127-
118+
128119
boundary = logits_sort.gather(1, (vocab_size - k).unsqueeze(dim=1))
129120
top_k_mask = logits_sort < boundary
130121
logits_sort.masked_fill_(top_k_mask, -float("inf"))
@@ -133,7 +124,10 @@ def _apply_top_k_top_p_npu(
133124
probs_sum = probs_sort.cumsum(dim=-1)
134125
top_p_mask = probs_sum > 1 - p.unsqueeze(dim=1)
135126
top_p_mask[:, -1] = True
136-
strides = torch.arange(0, batch_size*vocab_size, vocab_size, device=logits.device)
127+
strides = torch.arange(0,
128+
batch_size * vocab_size,
129+
vocab_size,
130+
device=logits.device)
137131
flatten_idx = logits_idx[:, cutoff:] + strides.unsqueeze(dim=1)
138132
valid_idx = torch.masked_select(flatten_idx, top_p_mask)
139133
logits_flatten = logits.flatten()

vllm_ascend/sample/sampler_v1.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import torch
2-
from vllm.v1.sample.sampler import Sampler
2+
from vllm.logger import init_logger
33
from vllm.v1.sample.metadata import SamplingMetadata
44
from vllm.v1.sample.ops.penalties import apply_min_token_penalties
5-
from vllm.logger import init_logger
6-
from vllm_ascend.sample.ops.ascend_topk_topp_sampler import AscendTopKTopPSampler
7-
from vllm_ascend.sample.ops.penalties import apply_all_penalties
5+
from vllm.v1.sample.sampler import Sampler
86

7+
from vllm_ascend.sample.ops.ascend_topk_topp_sampler import \
8+
AscendTopKTopPSampler
9+
from vllm_ascend.sample.ops.penalties import apply_all_penalties
910

1011
logger = init_logger(__name__)
1112

vllm_ascend/worker/model_runner.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
_add_sampling_metadata_broadcastable_dict,
6161
_init_attn_metadata_from_tensor_dict,
6262
_init_sampling_metadata_from_tensor_dict)
63+
6364
from vllm_ascend.sample.sampler import AscendSampler
6465

6566
if TYPE_CHECKING:
@@ -823,12 +824,7 @@ def load_model(self) -> None:
823824
logger.info("Starting to load model %s...", self.model_config.model)
824825
with DeviceMemoryProfiler() as m:
825826
self.model = get_model(vllm_config=self.vllm_config)
826-
# Same options with those in model_runner_v1.py
827-
# option 1
828-
if hasattr(self.model, "sampler"):
829-
self.model.sampler = AscendSampler()
830-
# option 2
831-
# self.model = NPUModelWrapperV1(model)
827+
self.model.sampler = AscendSampler()
832828
self.model_memory_usage = m.consumed_memory
833829
logger.info("Loading model weights took %.4f GB",
834830
self.model_memory_usage / float(2**30))

vllm_ascend/worker/model_runner_v1.py

Lines changed: 1 addition & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,7 @@
3333
from vllm.inputs import INPUT_REGISTRY
3434
from vllm.logger import init_logger
3535
from vllm.model_executor.layers.fused_moe import FusedMoE
36-
from vllm.model_executor.layers.sampler import sampler_output
3736
from vllm.model_executor.model_loader import get_model
38-
from vllm.model_executor.sampling_metadata import SamplingMetadata
3937
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
4038
from vllm.platforms import current_platform
4139
from vllm.sampling_params import SamplingType
@@ -808,11 +806,7 @@ def load_model(self) -> None:
808806

809807
with DeviceMemoryProfiler() as m: # noqa: SIM117
810808
self.model = get_model(vllm_config=self.vllm_config)
811-
# option 1
812-
if hasattr(self.model, "sampler"):
813-
self.model.sampler = AscendSampler()
814-
# option 2
815-
# self.model = NPUModelWrapperV1(model)
809+
self.model.sampler = AscendSampler()
816810

817811
if self.lora_config:
818812
raise ValueError("LoRA model is not supported on NPU now.")
@@ -893,25 +887,3 @@ def get_kv_cache_spec(self) -> KVCacheSpec:
893887
f"Unknown attention type: {attn_module.attn_type}")
894888

895889
return kv_cache_spec
896-
897-
# class NPUModelWrapperV1(nn.Module):
898-
899-
# def __init__(self, model: nn.Module):
900-
# super().__init__()
901-
# self._model = model
902-
# self.sampler = AscendSampler()
903-
904-
# def __getattr__(self, name):
905-
# return getattr(self._model, name)
906-
907-
# def sample(
908-
# self,
909-
# logits: Optional[torch.Tensor],
910-
# sampling_metadata: SamplingMetadata,
911-
# ) -> Optional[SamplerOutput]:
912-
# next_tokens = self.sampler(logits, sampling_metadata)
913-
# return next_tokens
914-
915-
# def forward():
916-
# # necessary if using wrapper class
917-
# pass

0 commit comments

Comments
 (0)