Skip to content

Commit 9845f0d

Browse files
authored
【Hackathon 9th No.30】add test_tritonmoe_preprocess (#3891)
* add test_tritonmoe_preprocess * add value check * del test_support_all...
1 parent c4830ef commit 9845f0d

File tree

1 file changed

+120
-0
lines changed

1 file changed

+120
-0
lines changed
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
import numpy as np
18+
import paddle
19+
20+
from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess
21+
22+
23+
class TestTritonMOEPreprocess(unittest.TestCase):
    """GPU tests for the ``tritonmoe_preprocess`` custom operator."""

    def setUp(self):
        """Select the GPU device and seed NumPy before every test."""
        paddle.set_device("gpu")
        np.random.seed(42)

    def _run_op(self, topk_ids_np, num_experts, GEMM_BLOCK_SIZE_M):
        """Run ``tritonmoe_preprocess`` on a NumPy input and return NumPy outputs.

        Args:
            topk_ids_np: NumPy array of expert ids; converted to an int64 Paddle tensor.
            num_experts: Number of experts passed to the operator.
            GEMM_BLOCK_SIZE_M: Block size passed to the operator.

        Returns:
            Tuple ``(sorted_ids, expert_ids, num_tokens_post_pad)`` as NumPy arrays.
        """
        topk_ids = paddle.to_tensor(topk_ids_np, dtype="int64")
        sorted_ids, expert_ids, num_tokens_post_pad = tritonmoe_preprocess(topk_ids, num_experts, GEMM_BLOCK_SIZE_M)
        return sorted_ids.numpy(), expert_ids.numpy(), num_tokens_post_pad.numpy()

    def _check_output_shapes(
        self, sorted_ids, expert_ids, num_tokens_post_pad, topk_ids_np, num_experts, GEMM_BLOCK_SIZE_M
    ):
        """Check the leading dimension and dtype of every operator output.

        Args:
            sorted_ids: First operator output as a NumPy array.
            expert_ids: Second operator output as a NumPy array.
            num_tokens_post_pad: Third operator output as a NumPy array.
            topk_ids_np: The NumPy input that was fed to the operator.
            num_experts: Number of experts passed to the operator.
            GEMM_BLOCK_SIZE_M: Block size passed to the operator.
        """
        # Upper bound used by this test: every expert may need up to
        # (GEMM_BLOCK_SIZE_M - 1) padding slots on top of the real entries.
        expected_max_num_tokens_padded = topk_ids_np.size + num_experts * (GEMM_BLOCK_SIZE_M - 1)
        self.assertEqual(sorted_ids.shape[0], expected_max_num_tokens_padded)

        expected_max_num_m_blocks = expected_max_num_tokens_padded // GEMM_BLOCK_SIZE_M
        self.assertEqual(expert_ids.shape[0], expected_max_num_m_blocks)

        self.assertEqual(num_tokens_post_pad.shape[0], 1)

        # assertEqual (rather than assertTrue(a == b)) reports both dtypes on failure.
        self.assertEqual(sorted_ids.dtype, np.int32)
        self.assertEqual(expert_ids.dtype, np.int32)
        self.assertEqual(num_tokens_post_pad.dtype, np.int32)

    def _check_output_values_basic(self, sorted_ids, expert_ids, num_tokens_post_pad):
        """Check the expected values for the fixed example of ``test_basic_case``.

        The expectations below are only valid for that example: a 4x4 ``topk_ids``
        input, 8 experts and a block size of 4.

        Args:
            sorted_ids: First operator output as a NumPy array.
            expert_ids: Second operator output as a NumPy array.
            num_tokens_post_pad: Third operator output as a NumPy array.
        """
        # One row per expert: the flat indices into topk_ids that selected that
        # expert, padded to a full block of 4 with 16 (== topk_ids.size).
        expected_sorted_ids = np.array(
            [
                [8, 12, 16, 16],  # expert 0
                [4, 9, 15, 16],  # expert 1
                [5, 10, 14, 16],  # expert 2
                [6, 11, 13, 16],  # expert 3
                [3, 7, 16, 16],  # expert 4
                [2, 16, 16, 16],  # expert 5
                [1, 16, 16, 16],  # expert 6
                [0, 16, 16, 16],  # expert 7
            ],
            dtype=np.int32,
        ).reshape(-1)
        np.testing.assert_array_equal(sorted_ids[: len(expected_sorted_ids)], expected_sorted_ids)

        expected_expert_ids = np.array([0, 1, 2, 3, 4, 5, 6, 7], dtype=np.int32)
        np.testing.assert_array_equal(expert_ids[: len(expected_expert_ids)], expected_expert_ids)

        # Each expert occupies exactly one block in the expected layout above, so
        # the padded token count is expected to equal its length (32). This is
        # stricter than only checking divisibility by the block size.
        self.assertEqual(int(num_tokens_post_pad[0]), expected_sorted_ids.size)

    def test_basic_case(self):
        """Basic fixed example test"""
        num_experts = 8
        GEMM_BLOCK_SIZE_M = 4
        topk_ids_np = np.array([[7, 6, 5, 4], [1, 2, 3, 4], [0, 1, 2, 3], [0, 3, 2, 1]], dtype=np.int64)

        sorted_ids, expert_ids, num_tokens_post_pad = self._run_op(topk_ids_np, num_experts, GEMM_BLOCK_SIZE_M)
        self._check_output_shapes(
            sorted_ids, expert_ids, num_tokens_post_pad, topk_ids_np, num_experts, GEMM_BLOCK_SIZE_M
        )
        self._check_output_values_basic(sorted_ids, expert_ids, num_tokens_post_pad)

    def test_unsupported_num_experts(self):
        """Test unsupported num_experts raises OSError"""
        topk_ids_np = np.array([[0, 1], [1, 0]], dtype=np.int64)
        unsupported_experts = [3, 9, 65, 129]
        GEMM_BLOCK_SIZE_M = 4

        for num_experts in unsupported_experts:
            # subTest keeps iterating after a failure and names the offending value.
            with self.subTest(num_experts=num_experts):
                with self.assertRaises(OSError):
                    self._run_op(topk_ids_np, num_experts, GEMM_BLOCK_SIZE_M)
117+
118+
119+
if __name__ == "__main__":
    # Run the test cases in this module when it is executed as a script.
    unittest.main()

0 commit comments

Comments
 (0)