
Commit 834c599

chengyuan committed:

add ut for vocab_parallel_embedding

Signed-off-by: chengyuan <[email protected]>

1 parent e7d32ed · commit 834c599

File tree: 1 file changed, +297 -0 lines changed
@@ -0,0 +1,297 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#

import unittest
from unittest.mock import MagicMock, patch

import torch
from vllm.model_executor.layers.vocab_parallel_embedding import \
    VocabParallelEmbedding
from vllm_ascend.ops.vocab_parallel_embedding import (
    get_masked_input_and_mask, vocab_parallel_embedding_forward)

VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS = 128


class TestGetMaskedInputAndMask(unittest.TestCase):

    def setUp(self):
        self.input_ = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])

    def test_get_masked_input_and_mask(self):
        # tp 1, no padding
        input_modified, _ = get_masked_input_and_mask(
            self.input_,
            org_vocab_start_index=0,
            org_vocab_end_index=8,
            added_vocab_start_index=8,
            added_vocab_end_index=12,
            num_org_vocab_padding=0)
        assert torch.equal(self.input_, input_modified)

        # tp 2, no padding
        input_rank_0, _ = get_masked_input_and_mask(self.input_,
                                                    org_vocab_start_index=0,
                                                    org_vocab_end_index=4,
                                                    added_vocab_start_index=8,
                                                    added_vocab_end_index=10,
                                                    num_org_vocab_padding=0)

        input_rank_1, _ = get_masked_input_and_mask(self.input_,
                                                    org_vocab_start_index=4,
                                                    org_vocab_end_index=8,
                                                    added_vocab_start_index=10,
                                                    added_vocab_end_index=12,
                                                    num_org_vocab_padding=0)

        assert torch.equal(input_rank_0,
                           torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 0, 0]))
        assert torch.equal(input_rank_1,
                           torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 4, 5]))

        # tp 4, no padding
        input_rank_0, _ = get_masked_input_and_mask(self.input_,
                                                    org_vocab_start_index=0,
                                                    org_vocab_end_index=2,
                                                    added_vocab_start_index=8,
                                                    added_vocab_end_index=9,
                                                    num_org_vocab_padding=0)

        input_rank_1, _ = get_masked_input_and_mask(self.input_,
                                                    org_vocab_start_index=2,
                                                    org_vocab_end_index=4,
                                                    added_vocab_start_index=9,
                                                    added_vocab_end_index=10,
                                                    num_org_vocab_padding=0)

        input_rank_2, _ = get_masked_input_and_mask(self.input_,
                                                    org_vocab_start_index=4,
                                                    org_vocab_end_index=6,
                                                    added_vocab_start_index=10,
                                                    added_vocab_end_index=11,
                                                    num_org_vocab_padding=0)

        input_rank_3, _ = get_masked_input_and_mask(self.input_,
                                                    org_vocab_start_index=6,
                                                    org_vocab_end_index=8,
                                                    added_vocab_start_index=11,
                                                    added_vocab_end_index=12,
                                                    num_org_vocab_padding=0)
        assert torch.equal(input_rank_0,
                           torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0]))
        assert torch.equal(input_rank_1,
                           torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0]))
        assert torch.equal(input_rank_2,
                           torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0]))
        assert torch.equal(input_rank_3,
                           torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2]))

        # tp 1, with padding
        input_modified, _ = get_masked_input_and_mask(
            self.input_,
            org_vocab_start_index=0,
            org_vocab_end_index=8,
            added_vocab_start_index=8,
            added_vocab_end_index=12,
            num_org_vocab_padding=2)
        assert torch.equal(
            input_modified,
            torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13]))

        # tp 2, with padding
        input_rank_0, _ = get_masked_input_and_mask(self.input_,
                                                    org_vocab_start_index=0,
                                                    org_vocab_end_index=4,
                                                    added_vocab_start_index=8,
                                                    added_vocab_end_index=10,
                                                    num_org_vocab_padding=2)

        input_rank_1, _ = get_masked_input_and_mask(self.input_,
                                                    org_vocab_start_index=4,
                                                    org_vocab_end_index=8,
                                                    added_vocab_start_index=10,
                                                    added_vocab_end_index=12,
                                                    num_org_vocab_padding=2)
        assert torch.equal(input_rank_0,
                           torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 6, 7, 0, 0]))
        assert torch.equal(input_rank_1,
                           torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 6, 7]))

        # tp 4, with padding
        input_rank_0, _ = get_masked_input_and_mask(self.input_,
                                                    org_vocab_start_index=0,
                                                    org_vocab_end_index=2,
                                                    added_vocab_start_index=8,
                                                    added_vocab_end_index=9,
                                                    num_org_vocab_padding=2)

        input_rank_1, _ = get_masked_input_and_mask(self.input_,
                                                    org_vocab_start_index=2,
                                                    org_vocab_end_index=4,
                                                    added_vocab_start_index=9,
                                                    added_vocab_end_index=10,
                                                    num_org_vocab_padding=2)

        input_rank_2, _ = get_masked_input_and_mask(self.input_,
                                                    org_vocab_start_index=4,
                                                    org_vocab_end_index=6,
                                                    added_vocab_start_index=10,
                                                    added_vocab_end_index=11,
                                                    num_org_vocab_padding=2)

        input_rank_3, _ = get_masked_input_and_mask(self.input_,
                                                    org_vocab_start_index=6,
                                                    org_vocab_end_index=8,
                                                    added_vocab_start_index=11,
                                                    added_vocab_end_index=12,
                                                    num_org_vocab_padding=2)
        assert torch.equal(input_rank_0,
                           torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0]))
        assert torch.equal(input_rank_1,
                           torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 4, 0, 0]))
        assert torch.equal(input_rank_2,
                           torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 4, 0]))
        assert torch.equal(input_rank_3,
                           torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 4]))


class TestVocabParallelEmbedding(unittest.TestCase):

    def setUp(self):
        # Create a mock VocabParallelEmbedding instance
        self.mock_embedding = MagicMock(spec=VocabParallelEmbedding)
        self.mock_embedding.tp_size = 2  # Test with tensor parallelism
        self.mock_embedding.shard_indices = MagicMock()
        self.mock_embedding.shard_indices.org_vocab_start_index = 10
        self.mock_embedding.shard_indices.org_vocab_end_index = 20
        self.mock_embedding.shard_indices.num_org_vocab_padding = 5
        self.mock_embedding.shard_indices.added_vocab_start_index = 30
        self.mock_embedding.shard_indices.added_vocab_end_index = 40
        self.mock_embedding.quant_method = MagicMock()

        # Set a consistent embedding dimension for all tests
        self.embedding_dim = 10
        # The mocked embedding returns a tensor of shape
        # (input_length, embedding_dim)
        self.mock_embedding.quant_method.embedding = MagicMock(
            side_effect=lambda _, x: torch.randn(x.shape[0],
                                                 self.embedding_dim))

    def test_get_masked_input_and_mask(self):
        """Test the mask and offset calculation helper function."""
        input_ = torch.tensor([5, 15, 25, 35, 45])  # includes all cases

        masked_input, mask = get_masked_input_and_mask(
            input_,
            org_vocab_start_index=10,
            org_vocab_end_index=20,
            num_org_vocab_padding=5,
            added_vocab_start_index=30,
            added_vocab_end_index=40)

        # The mask should be True for INVALID tokens (ones we want to mask out)
        expected_mask = torch.tensor([True, False, True, False, True])
        self.assertTrue(
            torch.equal(mask, expected_mask),
            f"Mask mismatch. Expected {expected_mask}, got {mask}")

        # Check masked input values
        expected_masked = torch.tensor([0, 5, 0, 20, 0])
        self.assertTrue(
            torch.equal(masked_input, expected_masked),
            f"Masked input mismatch. Expected {expected_masked}, got {masked_input}"
        )

    def test_forward_with_tp_size_1(self):
        """Test forward pass without tensor parallelism."""
        # Create a fresh mock embedding with tp_size=1
        mock_embedding = MagicMock(spec=VocabParallelEmbedding)
        mock_embedding.tp_size = 1
        mock_embedding.quant_method = MagicMock()
        mock_embedding.quant_method.embedding = MagicMock(
            return_value=torch.randn(3, self.embedding_dim))

        input_ = torch.tensor([1, 2, 3])

        with patch(
                "vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
                side_effect=lambda x: x) as mock_reduce_tp1:
            output = vocab_parallel_embedding_forward(mock_embedding, input_)

        # Should just pass through without masking
        mock_embedding.quant_method.embedding.assert_called_once_with(
            mock_embedding, input_.long())
        self.assertEqual(output.shape, (3, self.embedding_dim))

        # Verify all_reduce was called once
        mock_reduce_tp1.assert_called_once()

    def test_forward_with_tp(self):
        """Test forward pass with tensor parallelism."""
        input_ = torch.tensor([15, 35])  # one org vocab, one added vocab
        with patch(
                "vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
                side_effect=lambda x: x) as mock_reduce_tp:
            output = vocab_parallel_embedding_forward(self.mock_embedding,
                                                      input_)

        # Check that masking was applied correctly
        self.mock_embedding.quant_method.embedding.assert_called_once()
        called_input = self.mock_embedding.quant_method.embedding.call_args[0][
            1]
        expected_input = torch.tensor([5, 20])  # after offset calculation
        self.assertTrue(torch.all(called_input == expected_input))

        # Check that all_reduce was called
        mock_reduce_tp.assert_called_once()
        self.assertEqual(output.shape, (2, self.embedding_dim))

    def test_forward_with_invalid_vocab(self):
        """Test that invalid vocab indices are properly masked out."""
        input_ = torch.tensor([5, 15, 25, 35, 45])  # includes invalid cases

        # Create a predictable mock output
        mock_output = torch.randn(5, self.embedding_dim)
        self.mock_embedding.quant_method.embedding = MagicMock(
            return_value=mock_output.clone())
        with patch(
                "vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
                side_effect=lambda x: x):
            output = vocab_parallel_embedding_forward(self.mock_embedding,
                                                      input_)

        # Check that invalid positions (0, 2, 4) were zeroed out
        self.assertTrue(torch.all(output[0] == 0))
        self.assertTrue(torch.all(output[2] == 0))
        self.assertTrue(torch.all(output[4] == 0))
        self.assertTrue(torch.all(output[1] == mock_output[1]))
        self.assertTrue(torch.all(output[3] == mock_output[3]))
        self.assertEqual(output.shape, (5, self.embedding_dim))

    def test_output_shape(self):
        """Test that the output shape is correct."""
        test_cases = [
            (torch.tensor([15]), (1, self.embedding_dim)),
            (torch.tensor([15, 35]), (2, self.embedding_dim)),
            (torch.tensor([15, 35, 16, 36]), (4, self.embedding_dim)),
        ]

        for input_, expected_shape in test_cases:
            with self.subTest(input=input_):
                with patch(
                        "vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
                        side_effect=lambda x: x):
                    output = vocab_parallel_embedding_forward(
                        self.mock_embedding, input_)
                self.assertEqual(output.shape, expected_shape)
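
For readers of this commit, the expected values asserted above pin down the re-indexing that get_masked_input_and_mask performs on each tensor-parallel rank. The following is a minimal reference sketch of those semantics, written here for illustration only; it is not the vllm_ascend.ops.vocab_parallel_embedding implementation, and the name masked_input_and_mask_reference is hypothetical.

import torch


def masked_input_and_mask_reference(input_, org_vocab_start_index,
                                    org_vocab_end_index,
                                    num_org_vocab_padding,
                                    added_vocab_start_index,
                                    added_vocab_end_index):
    # Tokens in this rank's slice of the original vocabulary.
    org_mask = (input_ >= org_vocab_start_index) & (input_ < org_vocab_end_index)
    # Tokens in this rank's slice of the added vocabulary.
    added_mask = (input_ >= added_vocab_start_index) & (input_ < added_vocab_end_index)
    # Added-vocab tokens land right after the padded original slice, which is
    # why padding shifts them (e.g. token 8 -> 10 when num_org_vocab_padding == 2).
    org_size = org_vocab_end_index - org_vocab_start_index
    added_offset = added_vocab_start_index - org_size - num_org_vocab_padding
    offset = org_vocab_start_index * org_mask + added_offset * added_mask
    valid = org_mask | added_mask
    masked_input = (input_ - offset) * valid  # out-of-shard tokens collapse to 0
    return masked_input, ~valid  # mask is True for INVALID (out-of-shard) tokens

For the input [5, 15, 25, 35, 45] and the shard indices used in TestVocabParallelEmbedding, this sketch reproduces the asserted masked_input of [0, 5, 0, 20, 0] and the mask [True, False, True, False, True].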
