Commit 36d9483

Author: chengyuan (committed)
add ut for vocab_parallel_embedding
Signed-off-by: chengyuan <[email protected]>
1 parent e7d32ed commit 36d9483

File tree: 1 file changed (+299, −0 lines)
@@ -0,0 +1,299 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/lora/test_layers.py

from unittest.mock import MagicMock, patch

import torch
from tests.ut.base import TestBase
from vllm.model_executor.layers.vocab_parallel_embedding import \
    VocabParallelEmbedding

from vllm_ascend.ops.vocab_parallel_embedding import (
    get_masked_input_and_mask, vocab_parallel_embedding_forward)

VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS = 128


class TestGetMaskedInputAndMask(TestBase):

    def setUp(self):
        self.input_ = torch.arange(12)

    def test_get_masked_input_and_mask(self):
        # tp 1 no padding
        input_modified, _ = get_masked_input_and_mask(
            self.input_,
            org_vocab_start_index=0,
            org_vocab_end_index=8,
            added_vocab_start_index=8,
            added_vocab_end_index=12,
            num_org_vocab_padding=0)
        assert torch.equal(self.input_, input_modified)

        # tp 2 no padding
        input_rank_0, _ = get_masked_input_and_mask(self.input_,
                                                    org_vocab_start_index=0,
                                                    org_vocab_end_index=4,
                                                    added_vocab_start_index=8,
                                                    added_vocab_end_index=10,
                                                    num_org_vocab_padding=0)

        input_rank_1, _ = get_masked_input_and_mask(self.input_,
                                                    org_vocab_start_index=4,
                                                    org_vocab_end_index=8,
                                                    added_vocab_start_index=10,
                                                    added_vocab_end_index=12,
                                                    num_org_vocab_padding=0)

        assert torch.equal(input_rank_0,
                           torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 0, 0]))
        assert torch.equal(input_rank_1,
                           torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 4, 5]))

        # tp 4 no padding
        input_rank_0, _ = get_masked_input_and_mask(self.input_,
                                                    org_vocab_start_index=0,
                                                    org_vocab_end_index=2,
                                                    added_vocab_start_index=8,
                                                    added_vocab_end_index=9,
                                                    num_org_vocab_padding=0)

        input_rank_1, _ = get_masked_input_and_mask(self.input_,
                                                    org_vocab_start_index=2,
                                                    org_vocab_end_index=4,
                                                    added_vocab_start_index=9,
                                                    added_vocab_end_index=10,
                                                    num_org_vocab_padding=0)

        input_rank_2, _ = get_masked_input_and_mask(self.input_,
                                                    org_vocab_start_index=4,
                                                    org_vocab_end_index=6,
                                                    added_vocab_start_index=10,
                                                    added_vocab_end_index=11,
                                                    num_org_vocab_padding=0)

        input_rank_3, _ = get_masked_input_and_mask(self.input_,
                                                    org_vocab_start_index=6,
                                                    org_vocab_end_index=8,
                                                    added_vocab_start_index=11,
                                                    added_vocab_end_index=12,
                                                    num_org_vocab_padding=0)
        assert torch.equal(input_rank_0,
                           torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0]))
        assert torch.equal(input_rank_1,
                           torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0]))
        assert torch.equal(input_rank_2,
                           torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0]))
        assert torch.equal(input_rank_3,
                           torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2]))

        # tp 1 with padding
        input_modified, _ = get_masked_input_and_mask(
            self.input_,
            org_vocab_start_index=0,
            org_vocab_end_index=8,
            added_vocab_start_index=8,
            added_vocab_end_index=12,
            num_org_vocab_padding=2)
        assert torch.equal(
            input_modified,
            torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13]))

        # tp 2 with padding
        input_rank_0, _ = get_masked_input_and_mask(self.input_,
                                                    org_vocab_start_index=0,
                                                    org_vocab_end_index=4,
                                                    added_vocab_start_index=8,
                                                    added_vocab_end_index=10,
                                                    num_org_vocab_padding=2)

        input_rank_1, _ = get_masked_input_and_mask(self.input_,
                                                    org_vocab_start_index=4,
                                                    org_vocab_end_index=8,
                                                    added_vocab_start_index=10,
                                                    added_vocab_end_index=12,
                                                    num_org_vocab_padding=2)
        assert torch.equal(input_rank_0,
                           torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 6, 7, 0, 0]))
        assert torch.equal(input_rank_1,
                           torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 6, 7]))

        # tp 4 with padding
        input_rank_0, _ = get_masked_input_and_mask(self.input_,
                                                    org_vocab_start_index=0,
                                                    org_vocab_end_index=2,
                                                    added_vocab_start_index=8,
                                                    added_vocab_end_index=9,
                                                    num_org_vocab_padding=2)

        input_rank_1, _ = get_masked_input_and_mask(self.input_,
                                                    org_vocab_start_index=2,
                                                    org_vocab_end_index=4,
                                                    added_vocab_start_index=9,
                                                    added_vocab_end_index=10,
                                                    num_org_vocab_padding=2)

        input_rank_2, _ = get_masked_input_and_mask(self.input_,
                                                    org_vocab_start_index=4,
                                                    org_vocab_end_index=6,
                                                    added_vocab_start_index=10,
                                                    added_vocab_end_index=11,
                                                    num_org_vocab_padding=2)

        input_rank_3, _ = get_masked_input_and_mask(self.input_,
                                                    org_vocab_start_index=6,
                                                    org_vocab_end_index=8,
                                                    added_vocab_start_index=11,
                                                    added_vocab_end_index=12,
                                                    num_org_vocab_padding=2)
        assert torch.equal(input_rank_0,
                           torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0]))
        assert torch.equal(input_rank_1,
                           torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 4, 0, 0]))
        assert torch.equal(input_rank_2,
                           torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 4, 0]))
        assert torch.equal(input_rank_3,
                           torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 4]))


class TestVocabParallelEmbedding(TestBase):

    def setUp(self):
        # Create a mock VocabParallelEmbedding instance
        self.mock_embedding = MagicMock(spec=VocabParallelEmbedding)
        self.mock_embedding.tp_size = 2  # Test with tensor parallelism
        self.mock_embedding.shard_indices = MagicMock()
        self.mock_embedding.shard_indices.org_vocab_start_index = 10
        self.mock_embedding.shard_indices.org_vocab_end_index = 20
        self.mock_embedding.shard_indices.num_org_vocab_padding = 5
        self.mock_embedding.shard_indices.added_vocab_start_index = 30
        self.mock_embedding.shard_indices.added_vocab_end_index = 40
        self.mock_embedding.quant_method = MagicMock()

        # Set consistent embedding dimension for all tests
        self.embedding_dim = 10
        # Mock embedding returns tensor with shape (input_length, embedding_dim)
        self.mock_embedding.quant_method.embedding = MagicMock(
            side_effect=lambda _, x: torch.randn(x.shape[0],
                                                 self.embedding_dim))

    def test_get_masked_input_and_mask(self):
        """Test the mask and offset calculation helper function."""
        input_ = torch.tensor([5, 15, 25, 35, 45])  # includes all cases

        masked_input, mask = get_masked_input_and_mask(
            input_,
            org_vocab_start_index=10,
            org_vocab_end_index=20,
            num_org_vocab_padding=5,
            added_vocab_start_index=30,
            added_vocab_end_index=40)

        # The mask should be True for INVALID tokens (ones we want to mask out)
        expected_mask = torch.tensor([True, False, True, False, True])
        self.assertTrue(
            torch.equal(mask, expected_mask),
            f"Mask mismatch. Expected {expected_mask}, got {mask}")

        # Check masked input values
        expected_masked = torch.tensor([0, 5, 0, 20, 0])
        self.assertTrue(
            torch.equal(masked_input, expected_masked),
            f"Masked input mismatch. Expected {expected_masked}, got {masked_input}"
        )

    def test_forward_with_tp_size_1(self):
        """Test forward pass without tensor parallelism."""
        # Create a fresh mock embedding with tp_size=1
        mock_embedding = MagicMock(spec=VocabParallelEmbedding)
        mock_embedding.tp_size = 1
        mock_embedding.quant_method = MagicMock()
        mock_embedding.quant_method.embedding = MagicMock(
            return_value=torch.randn(3, self.embedding_dim))

        input_ = torch.tensor([1, 2, 3])

        with patch(
                "vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
                side_effect=lambda x: x) as mock_reduce_tp1:
            output = vocab_parallel_embedding_forward(mock_embedding, input_)

        # Should just pass through without masking
        mock_embedding.quant_method.embedding.assert_called_once_with(
            mock_embedding, input_.long())
        self.assertEqual(output.shape, (3, self.embedding_dim))

        # Verify all_reduce was called once
        mock_reduce_tp1.assert_called_once()

    def test_forward_with_tp(self):
        """Test forward pass with tensor parallelism."""
        input_ = torch.tensor([15, 35])  # one org vocab, one added vocab
        with patch(
                "vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
                side_effect=lambda x: x) as mock_reduce_tp:
            output = vocab_parallel_embedding_forward(self.mock_embedding,
                                                      input_)

        # Check that masking was applied correctly
        self.mock_embedding.quant_method.embedding.assert_called_once()
        called_input = (
            self.mock_embedding.quant_method.embedding.call_args[0][1])
        expected_input = torch.tensor([5, 20])  # after offset calculation
        self.assertTrue(torch.all(called_input == expected_input))

        # Check that all reduce was called
        mock_reduce_tp.assert_called_once()
        self.assertEqual(output.shape, (2, self.embedding_dim))

    def test_forward_with_invalid_vocab(self):
        """Test that invalid vocab indices are properly masked out."""
        input_ = torch.tensor([5, 15, 25, 35, 45])  # includes invalid cases

        # Create predictable mock output
        mock_output = torch.randn(5, self.embedding_dim)
        self.mock_embedding.quant_method.embedding = MagicMock(
            return_value=mock_output.clone())
        with patch(
                "vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
                side_effect=lambda x: x):
            output = vocab_parallel_embedding_forward(self.mock_embedding,
                                                      input_)

        # Check that invalid positions (0, 2, 4) were zeroed out
        self.assertTrue(torch.all(output[0] == 0))
        self.assertTrue(torch.all(output[2] == 0))
        self.assertTrue(torch.all(output[4] == 0))
        self.assertTrue(torch.all(output[1] == mock_output[1]))
        self.assertTrue(torch.all(output[3] == mock_output[3]))
        self.assertEqual(output.shape, (5, self.embedding_dim))

    def test_output_shape(self):
        """Test that output shape is correct."""
        test_cases = [
            (torch.tensor([15]), (1, self.embedding_dim)),
            (torch.tensor([15, 35]), (2, self.embedding_dim)),
            (torch.tensor([15, 35, 16, 36]), (4, self.embedding_dim)),
        ]

        for input_, expected_shape in test_cases:
            with self.subTest(input=input_):
                with patch(
                        "vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
                        side_effect=lambda x: x):
                    output = vocab_parallel_embedding_forward(
                        self.mock_embedding, input_)
                self.assertEqual(output.shape, expected_shape)
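
For reference, the index remapping these tests assert can be restated in a minimal pure-PyTorch sketch. This is not the vllm-ascend implementation of get_masked_input_and_mask; the function name masked_input_reference and its argument names are illustrative only. It captures the behavior the assertions above imply: tokens inside this rank's original-vocab shard map to x - org_vocab_start_index, tokens inside the added-vocab shard map to a slot just past the org shard plus its padding, and everything else collapses to index 0 with the returned mask set to True.

    import torch

    def masked_input_reference(input_, org_start, org_end, num_org_padding,
                               added_start, added_end):
        # Tokens owned by this rank's slice of the original vocabulary.
        org_mask = (input_ >= org_start) & (input_ < org_end)
        # Tokens owned by this rank's slice of the added vocabulary
        # (extra tokens appended after the base vocab).
        added_mask = (input_ >= added_start) & (input_ < added_end)
        # Added tokens land right after the org shard and its padding rows.
        added_offset = added_start - (org_end - org_start) - num_org_padding
        valid_offset = org_start * org_mask + added_offset * added_mask
        valid_mask = org_mask | added_mask
        # Out-of-shard tokens collapse to index 0; mask is True where INVALID.
        return valid_mask * (input_ - valid_offset), ~valid_mask

With org vocab [10, 20), padding 5, added vocab [30, 40) and input [5, 15, 25, 35, 45], this sketch reproduces the values asserted in test_get_masked_input_and_mask above: masked input [0, 5, 0, 20, 0] and mask [True, False, True, False, True].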
