Skip to content

Commit 555d6fd

Browse files
Merge branch 'vllm-project:main' into main
2 parents 10d3c8c + 7c90ba5 commit 555d6fd

File tree

4 files changed

+234
-1
lines changed

4 files changed

+234
-1
lines changed
Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
import pytest
2+
import torch
3+
from pytest_mock import MockerFixture
4+
from transformers import PretrainedConfig
5+
from vllm.config import CacheConfig, ModelConfig, VllmConfig
6+
7+
from tests.ut.base import PytestBase
8+
from vllm_ascend.models.deepseek_mtp import (
9+
CustomDeepSeekMTP, CustomDeepSeekMultiTokenPredictor,
10+
CustomDeepSeekMultiTokenPredictorLayer)
11+
12+
13+
class TestCustomDeepSeekMultiTokenPredictorLayer(PytestBase):
    """Unit tests for CustomDeepSeekMultiTokenPredictorLayer.

    All heavy submodules (vocab embedding, RMSNorm, shared head, decoder
    layer) have their ``__init__`` patched out so the layer can be built
    without any device/distributed initialization.
    """

    @pytest.fixture
    def setup_mtp_layer(self, mocker: MockerFixture):
        """Build an MTP layer with every expensive sub-layer init stubbed."""
        config = PretrainedConfig(vocab_size=1000,
                                  hidden_size=768,
                                  rms_norm_eps=1e-5)
        mocker.patch(
            "vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__",
            return_value=None)
        mocker.patch("vllm.model_executor.layers.layernorm.RMSNorm.__init__",
                     return_value=None)
        mocker.patch(
            "vllm.model_executor.models.deepseek_mtp.SharedHead.__init__",
            return_value=None)
        mocker.patch(
            "vllm_ascend.models.deepseek_mtp.CustomDeepSeekShareHead.__init__",
            return_value=None)
        mock_decoder_layer_init = mocker.patch(
            "vllm_ascend.models.deepseek_v2.CustomDeepseekV2DecoderLayer.__init__",
            return_value=None)

        mtp_layer = CustomDeepSeekMultiTokenPredictorLayer(config, "", None)
        # Constructing the predictor layer must build exactly one decoder layer.
        mock_decoder_layer_init.assert_called_once()
        return mtp_layer

    def test_init(self, setup_mtp_layer):
        """Construction succeeds and yields the expected type."""
        assert isinstance(setup_mtp_layer,
                          CustomDeepSeekMultiTokenPredictorLayer)

    def test_forward(self, mocker: MockerFixture, setup_mtp_layer):
        """forward() produces a hidden-state tensor of shape (2, 3, 768)."""
        mtp_layer = setup_mtp_layer
        # Bypass nn.Module attribute machinery so mocked attributes resolve
        # directly instead of going through _parameters/_modules lookups.
        mocker.patch("torch.nn.Module.__setattr__")
        mocker.patch("torch.nn.Module.__getattr__")
        mocker.patch("torch.nn.Module.__delattr__")
        mocker.patch.object(mtp_layer,
                            'eh_proj',
                            return_value=torch.randn(2, 3, 768))
        mocker.patch("torch.cat", return_value=torch.randn(2, 3, 768))
        # mtp_block is expected to return (hidden_states, residual).
        mtp_layer.mtp_block.return_value = (torch.randn(2, 3, 768),
                                            torch.randn(2, 3, 768))

        input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]])
        positions = torch.tensor([[0, 1, 2], [0, 1, 2]])
        kv_cache = torch.randn(2, 3, 768)
        previous_hidden_states = torch.randn(2, 3, 768)
        inputs_embeds = torch.tensor([[1.0, 2.0, 3.0]])

        output = mtp_layer(input_ids, positions, kv_cache, None,
                           previous_hidden_states, inputs_embeds, 0)
        assert output.shape == (2, 3, 768)
class TestCustomDeepSeekMultiTokenPredictor(PytestBase):
    """Unit tests for CustomDeepSeekMultiTokenPredictor."""

    @pytest.fixture
    def setup_predictor(self, mocker: MockerFixture):
        """Create a predictor from a mocked VllmConfig.

        The per-layer ``__init__`` is stubbed so no real MTP layers are
        materialized; only the predictor's own bookkeeping runs.
        """
        mock_vllm_config = mocker.MagicMock(spec=VllmConfig)
        mock_model_config = mocker.MagicMock(spec=ModelConfig)
        mock_hf_config = mocker.MagicMock()
        mock_hf_config.num_hidden_layers = 12
        mock_hf_config.num_nextn_predict_layers = 3
        mock_hf_config.vocab_size = 30000
        mock_model_config.hf_config = mock_hf_config
        mock_vllm_config.model_config = mock_model_config
        mock_vllm_config.cache_config = CacheConfig()
        mock_vllm_config.quant_config = mocker.MagicMock()
        mocker.patch(
            "vllm_ascend.models.deepseek_mtp.CustomDeepSeekMultiTokenPredictorLayer.__init__",
            return_value=None)

        return CustomDeepSeekMultiTokenPredictor(vllm_config=mock_vllm_config)

    def test_init(self, setup_predictor):
        """num_mtp_layers mirrors hf_config.num_nextn_predict_layers."""
        assert setup_predictor.num_mtp_layers == 3
        assert isinstance(setup_predictor, CustomDeepSeekMultiTokenPredictor)

    @pytest.mark.parametrize('kv_caches, inputs_embeds', [
        (torch.tensor([[[0.1, 0.2, 0.3]]]), torch.tensor([[0.1, 0.2, 0.3]])),
        (None, None),
    ])
    def test_forward(self, mocker: MockerFixture, setup_predictor, kv_caches,
                     inputs_embeds):
        """forward() delegates to the (mocked) MTP layer exactly once."""
        predictor = setup_predictor
        mock_layer = mocker.MagicMock()
        mock_layer.return_value = torch.tensor([1.0, 2.0, 3.0])
        predictor.layers_list = [mock_layer]

        input_ids = torch.tensor([[1, 2, 3]])
        positions = torch.tensor([[0, 1, 2]])
        # Patch __call__ as well, in case forward resolves the real layer
        # class instead of the layers_list mock — TODO confirm necessity.
        mocker.patch(
            "vllm_ascend.models.deepseek_mtp.CustomDeepSeekMultiTokenPredictorLayer.__call__",
            return_value=torch.tensor([[1.0, 2.0, 3.0]]))
        output = predictor.forward(input_ids, positions, kv_caches, None, None,
                                   inputs_embeds, 0)
        mock_layer.assert_called_once()
        assert torch.allclose(output, torch.tensor([1.0, 2.0, 3.0]))

    def test_compute_logits(self, mocker: MockerFixture, setup_predictor):
        """compute_logits() routes hidden states through the logits processor."""
        hidden_states = torch.tensor([[1, 2, 3], [4, 5, 6]])
        predictor = setup_predictor

        mock_layer = mocker.MagicMock()
        mock_layer.return_value = torch.tensor([1.0, 2.0, 3.0])
        predictor.layers_list = [mock_layer]
        # Bypass nn.Module attribute machinery so the logits_processor mock
        # can be assigned and read without module registration.
        mocker.patch("torch.nn.Module.__setattr__")
        mocker.patch("torch.nn.Module.__getattr__")
        mocker.patch("torch.nn.Module.__delattr__")
        mocker.patch(
            "vllm.model_executor.layers.logits_processor.LogitsProcessor.__init__",
            return_value=None)
        predictor.logits_processor.return_value = torch.tensor([1.0, 2.0, 3.0])

        result_logits = predictor.compute_logits(hidden_states=hidden_states,
                                                 sampling_metadata=None)
        predictor.logits_processor.assert_called_once()
        assert torch.allclose(result_logits, torch.tensor([1.0, 2.0, 3.0]))
class TestCustomDeepSeekMTP(PytestBase):
    """Unit tests for the top-level CustomDeepSeekMTP model wrapper."""

    @pytest.fixture
    def setup_mtp(self, mocker: MockerFixture):
        """Build a CustomDeepSeekMTP with config, layers, and sampler mocked."""
        vllm_config = mocker.MagicMock()
        vllm_config.model_config.hf_config.num_hidden_layers = 12
        vllm_config.model_config.hf_config.num_nextn_predict_layers = 3
        vllm_config.cache_config = mocker.MagicMock()
        vllm_config.quant_config = mocker.MagicMock()

        # Bypass nn.Module attribute machinery so mocked attributes resolve
        # directly instead of going through _parameters/_modules lookups.
        mocker.patch("torch.nn.Module.__setattr__")
        mocker.patch("torch.nn.Module.__getattr__")
        mocker.patch("torch.nn.Module.__delattr__")
        mocker.patch(
            "vllm_ascend.models.deepseek_mtp.CustomDeepSeekMultiTokenPredictorLayer.__call__",
            return_value=None)
        mocker.patch("vllm.model_executor.layers.sampler.get_sampler",
                     return_value=None)

        return CustomDeepSeekMTP(vllm_config=vllm_config)

    def test_init(self, setup_mtp):
        """Construction succeeds and yields the expected type."""
        assert isinstance(setup_mtp, CustomDeepSeekMTP)

    def test_forward(self, setup_mtp):
        """forward() returns whatever the (mocked) inner model produces."""
        input_ids = torch.tensor([[1, 2, 3]])
        positions = torch.tensor([[0, 1, 2]])
        kv_caches = [torch.tensor([[0.1, 0.2, 0.3]])]
        previous_hidden_states = torch.tensor([[0.1, 0.2, 0.3]])
        inputs_embeds = torch.tensor([[0.1, 0.2, 0.3]])
        spec_step_idx = 0
        setup_mtp.model.return_value = torch.tensor([[1.0, 2.0, 3.0]])

        output = setup_mtp.forward(input_ids, positions, kv_caches, None,
                                   previous_hidden_states, inputs_embeds,
                                   spec_step_idx)
        assert torch.allclose(output, torch.tensor([[1.0, 2.0, 3.0]]))

tests/ut/models/test_qwen2_5_vl_without_padding.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,8 @@ def init_vision_transformer(
231231
vision_config.in_channels = 3
232232
vision_config.hidden_act = "gelu"
233233
vision_config.depth = 0
234+
vision_config.hidden_size = 1280
235+
vision_config.num_heads = 16
234236

235237
mocker.patch("torch.nn.Module.__setattr__")
236238
mocker.patch("torch.nn.Module.__getattr__")
@@ -239,6 +241,10 @@ def init_vision_transformer(
239241
"vllm.model_executor.models.qwen2_5_vl.Qwen2_5_VisionTransformer.__init__",
240242
return_value=None,
241243
)
244+
mocker_vision_rotary_embedding = mocker.patch(
245+
"vllm_ascend.models.qwen2_5_vl.AscendQwen2_5_VisionRotaryEmbedding.__init__",
246+
return_value=None,
247+
)
242248
mocker.patch(
243249
"vllm_ascend.models.qwen2_5_vl_without_padding.AscendQwen2_5_VisionBlock_Without_Padding.__init__",
244250
return_value=None,
@@ -264,7 +270,7 @@ def init_vision_transformer(
264270
args, kwargs = mocker_vit.call_args
265271
assert args == (vision_config, norm_eps, None, "")
266272
assert not kwargs
267-
273+
mocker_vision_rotary_embedding.assert_called_once()
268274
return vision_transformer
269275

270276
def test_init_vision_transformer(self, mocker: MockerFixture):
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import pytest
2+
from pytest_mock import MockFixture
3+
4+
from tests.ut.base import PytestBase
5+
from vllm_ascend.multistream.decorator import set_multistream_support
6+
7+
8+
class Context:
    """Minimal stand-in for the forward context used by the decorator tests.

    The only state the tests inspect is ``attn_metadata``, which the
    decorator under test may replace per microbatch.
    """

    def __init__(self, attn_metadata=None):
        # Stored as-is; defaults to None when no metadata is supplied.
        self.attn_metadata = attn_metadata
class TestDecorator(PytestBase):
    """Tests for set_multistream_support's attn_metadata substitution."""

    @pytest.mark.parametrize(
        'layer_context, microbatch_context, expected_metadata', [
            # NOTE(review): -1 appears to mark an inactive context — whenever
            # either side reports -1, the original metadata is kept.
            ((-1, None, None), -1, {
                "original": True
            }),
            ((-1, None, None), 0, {
                "original": True
            }),
            ((0, None, None), -1, {
                "original": True
            }),
            # Both contexts active: the per-microbatch entry is substituted.
            ((0, None, [{
                "new": True
            }]), 0, {
                "new": True
            }),
        ])
    def test_decorator(self, mocker: MockFixture, layer_context,
                       microbatch_context, expected_metadata):

        def make_context():
            return Context(attn_metadata={"original": True})

        mocker.patch(
            'vllm_ascend.multistream.decorator.get_multistream_layer_context',
            return_value=layer_context)
        mocker.patch(
            'vllm_ascend.multistream.decorator.get_multistream_microbatch_context',
            return_value=microbatch_context)

        wrapped = set_multistream_support()(make_context)
        result = wrapped()
        assert result.attn_metadata == expected_metadata

vllm_ascend/models/qwen2_5_vl_without_padding.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@
4141
from vllm.model_executor.models.utils import maybe_prefix
4242
from vllm.multimodal import MULTIMODAL_REGISTRY
4343

44+
from vllm_ascend.models.qwen2_5_vl import AscendQwen2_5_VisionRotaryEmbedding
45+
4446

4547
class AscendQwen2_5_VisionAttention_Without_Padding(Qwen2_5_VisionAttention):
4648

@@ -160,6 +162,9 @@ def __init__(
160162
super().__init__(vision_config, norm_eps, quant_config, prefix)
161163
norm_layer = partial(RMSNorm, eps=norm_eps)
162164
self.interleaved = interleaved
165+
head_dim = self.hidden_size // self.num_heads
166+
self.rotary_pos_emb = AscendQwen2_5_VisionRotaryEmbedding(head_dim //
167+
2)
163168
self.patch_embed = AscendQwen2_5_VisionPatchEmbed_Without_Padding(
164169
patch_size=vision_config.patch_size,
165170
temporal_patch_size=vision_config.temporal_patch_size,

0 commit comments

Comments
 (0)