Commit 9b67c87

[Refactor] Refactor sampler (vllm-project#2050)

Refactor the Sampler implementation from the patch approach to inheriting from the vLLM Sampler interface.

Next step: make the `TopKTopPSampler` op in vLLM support the custom-ops registration mechanism.

- vLLM version: v0.10.0
- vLLM main: vllm-project/vllm@61a6905

Signed-off-by: wangxiyuan <[email protected]>
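For context, a minimal sketch of the before/after structure (illustrative only; the patched function body is elided, and the exact patching code lived in the now-deleted patch_sampler.py):

    # Before: the patch approach. A method on vLLM's Sampler is replaced at
    # import time, which silently couples vllm-ascend to upstream internals.
    import vllm.v1.sample.sampler as sampler_mod

    def _npu_apply_top_k_top_p(logits, k, p):  # hypothetical patched body
        ...

    sampler_mod.Sampler.apply_top_k_top_p = _npu_apply_top_k_top_p

    # After: the inheritance approach. Subclass the public Sampler interface
    # and swap in an Ascend-specific TopKTopPSampler (see
    # vllm_ascend/sample/sampler.py below).
    from vllm.v1.sample.sampler import Sampler
    from vllm_ascend.sample.sampler import AscendTopKTopPSampler

    class AscendSampler(Sampler):
        def __init__(self, logprobs_mode="raw_logprobs"):
            super().__init__(logprobs_mode=logprobs_mode)
            self.topk_topp_sampler = AscendTopKTopPSampler()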
1 parent b6a7f07 commit 9b67c87

File tree

8 files changed: +108 -150 lines changed

tests/ut/patch/worker/patch_common/test_patch_sampler.py

Lines changed: 0 additions & 46 deletions
This file was deleted.

tests/ut/sample/test_sampler.py

Lines changed: 32 additions & 0 deletions

@@ -0,0 +1,32 @@
+from unittest import mock
+
+import torch
+
+from tests.ut.base import TestBase
+from vllm_ascend.sample.sampler import AscendSampler, AscendTopKTopPSampler
+
+
+class TestAscendSampler(TestBase):
+
+    def test_init_with_raw_logprobs(self):
+        sampler = AscendSampler(logprobs_mode="raw_logprobs")
+        self.assertEqual(sampler.logprobs_mode, "raw_logprobs")
+        self.assertTrue(hasattr(sampler, 'topk_topp_sampler'))
+        self.assertIsInstance(sampler.topk_topp_sampler, AscendTopKTopPSampler)
+
+
+class TestAscendTopKTopPSampler(TestBase):
+
+    @mock.patch("torch_npu.npu_top_k_top_p")
+    def test_npu_topk_topp_called_when_optimized(self, mock_npu_op):
+        mock_npu_op.return_value = (torch.randn(1, 3))
+        sampler = AscendTopKTopPSampler()
+
+        logits = torch.tensor([[1.0, 2.0, 3.0]])
+        k = torch.tensor([2])
+        p = torch.tensor([0.9])
+        generators = {0: torch.Generator()}
+        generators[0].manual_seed(42)
+
+        sampler.forward_native(logits, generators, k, p)
+        mock_npu_op.assert_called_once_with(logits, p, k)
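The `generators` dict used above maps request indices to seeded `torch.Generator`s, which is what makes per-request sampling reproducible. A standalone illustration of that pattern (plain torch; vLLM's `random_sample` is conceptually similar but uses its own implementation):

    import torch

    probs = torch.tensor([[1.0, 2.0, 3.0]]).softmax(dim=-1)

    def draw(seed: int) -> int:
        # A fresh, seeded generator per request gives deterministic draws.
        gen = torch.Generator().manual_seed(seed)
        return int(torch.multinomial(probs, num_samples=1, generator=gen))

    assert draw(42) == draw(42)  # same seed, same sampled token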

vllm_ascend/envs.py

Lines changed: 3 additions & 3 deletions

@@ -128,11 +128,11 @@
     "VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE":
     lambda: int(
         os.getenv("VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE", 64)),
-    # Whether to enable the topk optimization. It's disabled by default for experimental support
-    # We'll make it enabled by default in the future.
+    # Whether to enable the topk optimization. It's enabled by default. Please set to False if you hit any issue.
+    # We'll remove this flag in the future once it's stable enough.
     "VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION":
     lambda: bool(
-        int(os.getenv("VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION", '0'))),
+        int(os.getenv("VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION", '1'))),
 
     # `LLMDataDistCMgrConnector` required variable. `DISAGGREGATED_PREFILL_RANK_TABLE_PATH` is
     # used for llmdatadist to build the communication topology for kv cache transfer, it is
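Since the flag now defaults to on, opting out must be explicit. A usage sketch (assuming, as the `lambda` above suggests, that the flag is read lazily when `envs` is accessed):

    import os

    # Set before vllm-ascend reads the flag, e.g. at the top of a launch script.
    os.environ["VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION"] = "0"  # new default is "1"

    from vllm_ascend import envs
    assert envs.VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION is False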

vllm_ascend/patch/__init__.py

Lines changed: 1 addition & 15 deletions

@@ -88,21 +88,7 @@
 #   Future Plan:
 #       Remove this patch once pytorch 2.7.0 is supported for vllm ascend.
 #
-# ** File: worker/patch_common/patch_sampler.py **
-#    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-#   1. `vllm.v1.sample.sampler.Sampler.apply_top_k_top_p`
-#    Why:
-#       We need to use the patched `apply_top_k_top_p` in `sample`.
-#       The mainly reason to overwrite `apply_top_k_top_p` is
-#       to improve performance.
-#    How:
-#       Re-implementation the `apply_top_k_top_p` function by pytorch
-#    Related PR (if no, explain why):
-#       - https://github.com/vllm-project/vllm-ascend/pull/1732
-#    Future Plan:
-#       Revert it when the ascend scatter performance improves.
-#
-# ** File: worker/patch_common/patch_sampler.py **
+# ** File: worker/patch_0_10_0/patch_sampler_gather_logprobs.py **
 #    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 #   1. `vllm.v1.sample.sampler.Sampler.gather_logprobs`
 #    Why:

vllm_ascend/patch/worker/patch_common/__init__.py

Lines changed: 0 additions & 1 deletion

@@ -21,4 +21,3 @@
 import vllm_ascend.patch.worker.patch_common.patch_distributed  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_linear  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_minicpm  # noqa
-import vllm_ascend.patch.worker.patch_common.patch_sampler  # noqa

vllm_ascend/patch/worker/patch_common/patch_sampler.py

Lines changed: 0 additions & 83 deletions
This file was deleted.

vllm_ascend/sample/sampler.py

Lines changed: 62 additions & 0 deletions

@@ -0,0 +1,62 @@
+import torch
+import torch_npu
+from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample
+from vllm.v1.sample.sampler import Sampler
+
+
+class AscendSampler(Sampler):
+
+    def __init__(self, logprobs_mode="raw_logprobs"):
+        # TODO: support logprobs_mode in vllm-ascend
+        super().__init__(logprobs_mode=logprobs_mode)
+        self.topk_topp_sampler = AscendTopKTopPSampler()
+
+
+class AscendTopKTopPSampler(TopKTopPSampler):
+
+    def _apply_top_k_top_p(
+        self,
+        logits: torch.Tensor,
+        k: torch.Tensor,
+        p: torch.Tensor,
+    ) -> torch.Tensor:
+        if p is not None and k is not None:
+            # npu_top_k_top_p's parameter order is (logits, p, k), not (logits, k, p)
+            return torch_npu.npu_top_k_top_p(logits, p, k)
+
+        if p is None and k is None:
+            return logits
+
+        probs = logits.softmax(dim=-1)
+        probs_sort, _ = probs.sort(dim=-1, descending=False)
+
+        if k is not None:
+            top_k_count = probs_sort.size(1) - k.to(
+                torch.long)  # shape: (batch, )
+            top_k_count = top_k_count.unsqueeze(dim=1)
+            top_k_cutoff = probs_sort.gather(-1, top_k_count)
+
+            # Make sure the no top-k rows are no-op.
+            no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1)
+            top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))
+
+            elements_to_discard = probs < top_k_cutoff
+            logits.masked_fill_(elements_to_discard, -float("inf"))
+
+        if p is not None:
+            cumprob = torch.cumsum(probs_sort, dim=-1)
+            top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
+            top_p_mask[:, -1] = False  # at least one
+
+            top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
+            top_p_cutoff = probs_sort.gather(-1, top_p_count)
+            elements_to_discard = probs < top_p_cutoff
+            logits.masked_fill_(elements_to_discard, -float("inf"))
+
+        return logits
+
+    def forward_native(self, logits, generators, k, p):
+        """Override pytorch native implementation to torch_npu"""
+        logits = self._apply_top_k_top_p(logits, k, p)
+        probs = logits.softmax(dim=-1, dtype=torch.float32)
+        return random_sample(probs, generators)
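To see what the pure-PyTorch fallback branch computes (taken when only one of `k`/`p` is set), here is a small worked example of the top-k cutoff logic, runnable without `torch_npu`:

    import torch

    # One row of logits over a 4-token vocabulary, top-k = 2.
    logits = torch.tensor([[1.0, 3.0, 2.0, 0.5]])
    k = torch.tensor([2])

    probs = logits.softmax(dim=-1)
    probs_sort, _ = probs.sort(dim=-1, descending=False)  # ascending, as above

    # With an ascending sort, the k-th largest probability sits at index
    # (vocab_size - k); everything strictly below it gets discarded.
    top_k_count = (probs_sort.size(1) - k.to(torch.long)).unsqueeze(dim=1)
    top_k_cutoff = probs_sort.gather(-1, top_k_count)

    masked = logits.masked_fill(probs < top_k_cutoff, -float("inf"))
    print(masked)  # tensor([[-inf, 3., 2., -inf]]) -- only the top 2 survive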

vllm_ascend/worker/model_runner_v1.py

Lines changed: 10 additions & 2 deletions

@@ -64,14 +64,14 @@
                                   ModelRunnerOutput)
 from vllm.v1.pool.metadata import PoolingMetadata
 from vllm.v1.sample.metadata import SamplingMetadata
-from vllm.v1.sample.sampler import Sampler
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 from vllm.v1.worker.utils import (bind_kv_cache, gather_mm_placeholders,
                                   sanity_check_mm_encoder_outputs,
                                   scatter_mm_placeholders)
 
+from vllm_ascend import envs
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import set_ascend_forward_context
 from vllm_ascend.attention.attention_mask import AttentionMaskBuilder

@@ -165,7 +165,15 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.dp_rank = vllm_config.parallel_config.data_parallel_rank
         self.device = device
         self.dtype = self.model_config.dtype
-        self.sampler = Sampler()
+        if envs.VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION:
+            # TODO: drop the env config to use ascend sampler by default
+            from vllm_ascend.sample.sampler import AscendSampler
+
+            self.sampler = AscendSampler()
+        else:
+            from vllm.v1.sample.sampler import Sampler
+
+            self.sampler = Sampler()
 
         # Lazy initialization, these will be set after __init__
         self.kv_caches: List[torch.Tensor] = []
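One design note on the runner change (an inference, not stated in the PR): moving the `Sampler` import into the branch means only the selected implementation is imported at construction time. The pattern in isolation, with a hypothetical helper name:

    def make_sampler(use_ascend_optimization: bool):
        # Deferred imports: only the chosen backend is imported, so the
        # other backend's module (and its dependencies) never loads.
        if use_ascend_optimization:
            from vllm_ascend.sample.sampler import AscendSampler  # pulls in torch_npu
            return AscendSampler()
        from vllm.v1.sample.sampler import Sampler
        return Sampler()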
