Commit 33111a7

Author: chengyuan
Message: add ut for vocab_parallel_embedding
Parent: 4a008c4
File tree: 1 file changed, 275 additions (+), 0 deletions (-)

@@ -0,0 +1,275 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#

import math
import random
import unittest
from unittest.mock import MagicMock, patch
from typing import Tuple, Optional

import torch
import torch.nn as nn
import pytest
from tests.ut.base import TestBase
from vllm_ascend.ops.vocab_parallel_embedding import (
    get_masked_input_and_mask, vocab_parallel_embedding_forward)
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, UnquantizedEmbeddingMethod)

VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS = 128

class TestGetMaskedInputAndMask(unittest.TestCase):

    def setUp(self):
        self.input_ = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])

    def test_get_masked_input_and_mask(self):
        # tp 1 no padding
        input_modified, _ = get_masked_input_and_mask(
            self.input_,
            org_vocab_start_index=0,
            org_vocab_end_index=8,
            added_vocab_start_index=8,
            added_vocab_end_index=12,
            num_org_vocab_padding=0)
        assert torch.equal(self.input_, input_modified)

        # tp 2 no padding
        input_rank_0, _ = get_masked_input_and_mask(
            self.input_,
            org_vocab_start_index=0,
            org_vocab_end_index=4,
            added_vocab_start_index=8,
            added_vocab_end_index=10,
            num_org_vocab_padding=0)

        input_rank_1, _ = get_masked_input_and_mask(
            self.input_,
            org_vocab_start_index=4,
            org_vocab_end_index=8,
            added_vocab_start_index=10,
            added_vocab_end_index=12,
            num_org_vocab_padding=0)

        assert torch.equal(
            input_rank_0, torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 0, 0]))
        assert torch.equal(
            input_rank_1, torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 4, 5]))

        # tp 4 no padding
        input_rank_0, _ = get_masked_input_and_mask(
            self.input_,
            org_vocab_start_index=0,
            org_vocab_end_index=2,
            added_vocab_start_index=8,
            added_vocab_end_index=9,
            num_org_vocab_padding=0)

        input_rank_1, _ = get_masked_input_and_mask(
            self.input_,
            org_vocab_start_index=2,
            org_vocab_end_index=4,
            added_vocab_start_index=9,
            added_vocab_end_index=10,
            num_org_vocab_padding=0)

        input_rank_2, _ = get_masked_input_and_mask(
            self.input_,
            org_vocab_start_index=4,
            org_vocab_end_index=6,
            added_vocab_start_index=10,
            added_vocab_end_index=11,
            num_org_vocab_padding=0)

        input_rank_3, _ = get_masked_input_and_mask(
            self.input_,
            org_vocab_start_index=6,
            org_vocab_end_index=8,
            added_vocab_start_index=11,
            added_vocab_end_index=12,
            num_org_vocab_padding=0)

        assert torch.equal(
            input_rank_0, torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0]))
        assert torch.equal(
            input_rank_1, torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0]))
        assert torch.equal(
            input_rank_2, torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0]))
        assert torch.equal(
            input_rank_3, torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2]))

        # tp 1 with padding
        input_modified, _ = get_masked_input_and_mask(
            self.input_,
            org_vocab_start_index=0,
            org_vocab_end_index=8,
            added_vocab_start_index=8,
            added_vocab_end_index=12,
            num_org_vocab_padding=2)
        assert torch.equal(
            input_modified,
            torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13]))

        # tp 2 with padding
        input_rank_0, _ = get_masked_input_and_mask(
            self.input_,
            org_vocab_start_index=0,
            org_vocab_end_index=4,
            added_vocab_start_index=8,
            added_vocab_end_index=10,
            num_org_vocab_padding=2)

        input_rank_1, _ = get_masked_input_and_mask(
            self.input_,
            org_vocab_start_index=4,
            org_vocab_end_index=8,
            added_vocab_start_index=10,
            added_vocab_end_index=12,
            num_org_vocab_padding=2)

        assert torch.equal(
            input_rank_0, torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 6, 7, 0, 0]))
        assert torch.equal(
            input_rank_1, torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 6, 7]))

        # tp 4 with padding
        input_rank_0, _ = get_masked_input_and_mask(
            self.input_,
            org_vocab_start_index=0,
            org_vocab_end_index=2,
            added_vocab_start_index=8,
            added_vocab_end_index=9,
            num_org_vocab_padding=2)

        input_rank_1, _ = get_masked_input_and_mask(
            self.input_,
            org_vocab_start_index=2,
            org_vocab_end_index=4,
            added_vocab_start_index=9,
            added_vocab_end_index=10,
            num_org_vocab_padding=2)

        input_rank_2, _ = get_masked_input_and_mask(
            self.input_,
            org_vocab_start_index=4,
            org_vocab_end_index=6,
            added_vocab_start_index=10,
            added_vocab_end_index=11,
            num_org_vocab_padding=2)

        input_rank_3, _ = get_masked_input_and_mask(
            self.input_,
            org_vocab_start_index=6,
            org_vocab_end_index=8,
            added_vocab_start_index=11,
            added_vocab_end_index=12,
            num_org_vocab_padding=2)

        assert torch.equal(
            input_rank_0, torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0]))
        assert torch.equal(
            input_rank_1, torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 4, 0, 0]))
        assert torch.equal(
            input_rank_2, torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 4, 0]))
        assert torch.equal(
            input_rank_3, torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 4]))

class TestVocabParallelEmbedding(unittest.TestCase):

    def setUp(self):
        # Create a mock VocabParallelEmbedding instance
        self.mock_embedding = MagicMock(spec=VocabParallelEmbedding)
        self.mock_embedding.tp_size = 2  # Test with tensor parallelism
        self.mock_embedding.shard_indices = MagicMock()
        self.mock_embedding.shard_indices.org_vocab_start_index = 10
        self.mock_embedding.shard_indices.org_vocab_end_index = 20
        self.mock_embedding.shard_indices.num_org_vocab_padding = 5
        self.mock_embedding.shard_indices.added_vocab_start_index = 30
        self.mock_embedding.shard_indices.added_vocab_end_index = 40
        self.mock_embedding.quant_method = MagicMock()

        # Set consistent embedding dimension for all tests
        self.embedding_dim = 10
        # Mock embedding returns tensor with shape (input_length, embedding_dim)
        self.mock_embedding.quant_method.embedding = MagicMock(
            side_effect=lambda _, x: torch.randn(x.shape[0], self.embedding_dim)
        )

    def test_get_masked_input_and_mask(self):
        """Test the mask and offset calculation helper function."""
        input_ = torch.tensor([5, 15, 25, 35, 45])  # includes all cases

        masked_input, mask = get_masked_input_and_mask(
            input_,
            org_vocab_start_index=10,
            org_vocab_end_index=20,
            num_org_vocab_padding=5,
            added_vocab_start_index=30,
            added_vocab_end_index=40
        )

        # The mask should be True for INVALID tokens (ones we want to mask out)
        expected_mask = torch.tensor([True, False, True, False, True])
        self.assertTrue(
            torch.equal(mask, expected_mask),
            f"Mask mismatch. Expected {expected_mask}, got {mask}")

        # Check masked input values
        expected_masked = torch.tensor([0, 5, 0, 20, 0])
        self.assertTrue(
            torch.equal(masked_input, expected_masked),
            f"Masked input mismatch. Expected {expected_masked}, "
            f"got {masked_input}")

    def test_forward_with_tp_size_1(self):
        """Test forward pass without tensor parallelism."""
        # Create a fresh mock embedding with tp_size=1
        mock_embedding = MagicMock(spec=VocabParallelEmbedding)
        mock_embedding.tp_size = 1
        mock_embedding.quant_method = MagicMock()
        mock_embedding.quant_method.embedding = MagicMock(
            return_value=torch.randn(3, self.embedding_dim)
        )

        input_ = torch.tensor([1, 2, 3])

        with patch(
                "vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
                side_effect=lambda x: x) as mock_reduce_tp1:
            output = vocab_parallel_embedding_forward(mock_embedding, input_)

            # Should just pass through without masking
            mock_embedding.quant_method.embedding.assert_called_once_with(
                mock_embedding, input_.long()
            )
            self.assertEqual(output.shape, (3, self.embedding_dim))

            # Verify all_reduce was called once
            mock_reduce_tp1.assert_called_once()

    def test_forward_with_tp(self):
        """Test forward pass with tensor parallelism."""
        input_ = torch.tensor([15, 35])  # one org vocab, one added vocab
        with patch(
                "vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
                side_effect=lambda x: x) as mock_reduce_tp:
            output = vocab_parallel_embedding_forward(self.mock_embedding, input_)

            # Check that masking was applied correctly
            self.mock_embedding.quant_method.embedding.assert_called_once()
            called_input = self.mock_embedding.quant_method.embedding.call_args[0][1]
            expected_input = torch.tensor([5, 20])  # after offset calculation
            self.assertTrue(torch.all(called_input == expected_input))

            # Check that all reduce was called
            # self.dist_mock.tensor_model_parallel_all_reduce.assert_called_once()
            mock_reduce_tp.assert_called_once()
            self.assertEqual(output.shape, (2, self.embedding_dim))

    def test_forward_with_invalid_vocab(self):
        """Test that invalid vocab indices are properly masked out."""
        input_ = torch.tensor([5, 15, 25, 35, 45])  # includes invalid cases

        # Create predictable mock output
        mock_output = torch.randn(5, self.embedding_dim)
        self.mock_embedding.quant_method.embedding = MagicMock(
            return_value=mock_output.clone())
        with patch(
                "vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
                side_effect=lambda x: x) as mock_reduce:
            output = vocab_parallel_embedding_forward(self.mock_embedding, input_)

            # Check that invalid positions (0, 2, 4) were zeroed out
            self.assertTrue(torch.all(output[0] == 0))
            self.assertTrue(torch.all(output[2] == 0))
            self.assertTrue(torch.all(output[4] == 0))
            self.assertTrue(torch.all(output[1] == mock_output[1]))
            self.assertTrue(torch.all(output[3] == mock_output[3]))
            self.assertEqual(output.shape, (5, self.embedding_dim))

    def test_output_shape(self):
        """Test that output shape is correct."""
        test_cases = [
            (torch.tensor([15]), (1, self.embedding_dim)),
            (torch.tensor([15, 35]), (2, self.embedding_dim)),
            (torch.tensor([15, 35, 16, 36]), (4, self.embedding_dim)),
        ]

        for input_, expected_shape in test_cases:
            with self.subTest(input=input_):
                with patch(
                        "vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
                        side_effect=lambda x: x) as mock_reduce:
                    output = vocab_parallel_embedding_forward(
                        self.mock_embedding, input_)
                    self.assertEqual(output.shape, expected_shape)
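For readers skimming the diff: the assertions in TestGetMaskedInputAndMask fully pin down the index remapping that get_masked_input_and_mask is expected to perform. Below is a minimal reference sketch of that mapping, inferred only from the test expectations above; it is not the vllm-ascend implementation, and the function name is invented for illustration.

import torch

def reference_masked_input_and_mask(input_, org_vocab_start_index,
                                    org_vocab_end_index, num_org_vocab_padding,
                                    added_vocab_start_index,
                                    added_vocab_end_index):
    # In-shard original-vocab ids map to [0, org_len); in-shard added-vocab
    # ids map to the slots after the padded original shard; all other ids
    # are masked to index 0 and reported as invalid (mask == True).
    org_mask = (input_ >= org_vocab_start_index) & (input_ < org_vocab_end_index)
    added_mask = (input_ >= added_vocab_start_index) & (input_ < added_vocab_end_index)
    added_offset = (added_vocab_start_index -
                    (org_vocab_end_index - org_vocab_start_index) -
                    num_org_vocab_padding)
    valid = org_mask | added_mask
    offset = org_vocab_start_index * org_mask + added_offset * added_mask
    masked_input = (input_ - offset) * valid
    return masked_input, ~valid

Run against the "tp 2 no padding" rank-0 case above (original vocab [0, 4), added vocab [8, 10), no padding), this sketch reproduces the expected tensor [0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 0, 0] and flags every other position as invalid.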

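Note on running the new suite: because tensor_model_parallel_all_reduce is patched and quant_method.embedding is a MagicMock, both test classes exercise only the indexing and masking logic, with no real tensor-parallel communication. The file uses plain unittest, so it can be driven with python -m unittest or pytest against wherever it lands in the repository; the exact path is not shown on this page, though the tests.ut.base import suggests it sits under the tests/ut tree.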