@@ -110,143 +110,6 @@ def test_fused_moe(
110
110
rtol = 0 )
111
111
112
112
113
- def torch_dispatch (
114
- a : torch .Tensor ,
115
- topk_ids : torch .Tensor ,
116
- num_experts : int
117
- ) -> torch .Tensor :
118
- assert topk_ids .dim () == 2
119
- assert topk_ids .shape [0 ] == a .shape [0 ]
120
-
121
- num_tokens = a .shape [0 ]
122
- topk = topk_ids .shape [1 ]
123
-
124
- tokens_per_expert = torch .bincount (topk_ids .view (- 1 ), minlength = num_experts )
125
-
126
- max_num_tokens = tokens_per_expert .max ()
127
- b_a = torch .zeros ((num_experts , max_num_tokens , a .shape [1 ]),
128
- dtype = a .dtype , device = a .device )
129
- #print(f"b_a shape {b_a.shape}")
130
-
131
- token_counts = torch .zeros (num_experts , dtype = torch .int , device = a .device )
132
-
133
- for token in range (num_tokens ):
134
- for j in range (topk ):
135
- expert_id = topk_ids [token , j ]
136
- idx = token_counts [expert_id ]
137
- b_a [expert_id , idx :idx + 1 , :] = a [token , :]
138
- token_counts [expert_id ] = token_counts [expert_id ] + 1
139
-
140
- return b_a , tokens_per_expert
141
-
142
-
143
- def torch_combine (b_out , topk_weight , topk_ids ):
144
- num_tokens , topk = topk_ids .shape
145
- num_experts = b_out .shape [0 ]
146
- K = b_out .shape [- 1 ]
147
- out = torch .zeros ((num_tokens , K ), dtype = b_out .dtype , device = b_out .device )
148
- expert_counts = torch .zeros (num_experts , dtype = torch .int , device = b_out .device )
149
- for token in range (num_tokens ):
150
- expert_ids = topk_ids [token ]
151
- for i in range (expert_ids .numel ()):
152
- expert_id = expert_ids [i ]
153
- idx = expert_counts [expert_id ]
154
- out [token , :] = out [token , :] + b_out [expert_id , idx :idx + 1 , :] * topk_weight [token , i ]
155
- expert_counts [expert_id ] = expert_counts [expert_id ] + 1
156
-
157
- return out
158
-
159
-
160
- def torch_batched_moe (a , w1 , w2 , topk_weight , topk_ids ):
161
- num_experts = w1 .shape [0 ]
162
- b_a , tokens_per_expert = torch_dispatch (a , topk_ids , num_experts )
163
- assert b_a .dim () == 3
164
- num_tokens , topk = topk_ids .shape
165
- _ , max_num_tokens , K = b_a .shape
166
- assert num_experts == b_a .shape [0 ] and K == w2 .shape [1 ]
167
- out = torch .zeros ((num_experts , max_num_tokens , K ), dtype = b_a .dtype , device = b_a .device )
168
- tmp = torch .empty ((max_num_tokens , w1 .shape [1 ] // 2 ), dtype = b_a .dtype , device = b_a .device )
169
- for expert in range (num_experts ):
170
- num = tokens_per_expert [expert ]
171
- if num > 0 :
172
- torch .ops ._C .silu_and_mul (tmp [:num ], b_a [expert ,:num ,:] @ w1 [expert ].transpose (0 , 1 ))
173
- out [expert , :num , :] = tmp [:num ] @ w2 [expert ].transpose (0 , 1 )
174
-
175
- return torch_combine (out , topk_weight , topk_ids )
176
-
177
-
178
- # TODO: same as torch_moe but with fused_topk factored out.
179
- def torch_moe2 (a , w1 , w2 , topk_weight , topk_ids ):
180
- M , K = a .shape
181
- topk = topk_ids .shape [1 ]
182
- a = a .view (M , - 1 , K ).repeat (1 , topk , 1 ).reshape (- 1 , K )
183
- out = torch .zeros (M * topk , w2 .shape [1 ], dtype = a .dtype , device = a .device )
184
- num_experts = w1 .shape [0 ]
185
- for i in range (num_experts ):
186
- mask = (topk_ids == i ).view (- 1 )
187
- if mask .sum ():
188
- out [mask ] = SiluAndMul ()(
189
- a [mask ] @ w1 [i ].transpose (0 , 1 )) @ w2 [i ].transpose (0 , 1 )
190
-
191
- return (out .view (M , - 1 , w2 .shape [1 ]) *
192
- topk_weight .view (M , - 1 , 1 ).to (out .dtype )).sum (dim = 1 )
193
-
194
-
195
- @pytest .mark .parametrize ("m" , [1 , 33 , 64 , 222 ]) #, 1024 * 128])
196
- @pytest .mark .parametrize ("n" , [128 , 1024 , 2048 ])
197
- @pytest .mark .parametrize ("k" , [128 , 511 , 1024 ])
198
- @pytest .mark .parametrize ("e" , NUM_EXPERTS )
199
- @pytest .mark .parametrize ("topk" , TOP_KS )
200
- @pytest .mark .parametrize ("dtype" , [torch .float16 , torch .bfloat16 ])
201
- def test_fused_moe_batched_experts (
202
- m : int ,
203
- n : int ,
204
- k : int ,
205
- e : int ,
206
- topk : int ,
207
- dtype : torch .dtype ,
208
- ):
209
- current_platform .seed_everything (7 )
210
-
211
- a = torch .randn ((m , k ), device = "cuda" , dtype = dtype ) / 10
212
- w1 = torch .randn ((e , 2 * n , k ), device = "cuda" , dtype = dtype ) / 10
213
- w2 = torch .randn ((e , k , n ), device = "cuda" , dtype = dtype ) / 10
214
-
215
- score = torch .randn ((m , e ), device = "cuda" , dtype = dtype )
216
-
217
- vllm_config = VllmConfig ()
218
- with set_current_vllm_config (vllm_config ):
219
- topk_weight , topk_ids = fused_topk (a , score , topk , False )
220
-
221
- torch_output = torch_moe2 (a , w1 , w2 , topk_weight , topk_ids )
222
-
223
- if True :
224
- triton_output = torch_batched_moe (a ,
225
- w1 ,
226
- w2 ,
227
- topk_weight ,
228
- topk_ids )
229
- else :
230
- b_a , tokens_per_expert = batch_by_experts (a , topk_ids , e )
231
- triton_output = fused_batched_experts (
232
- b_a ,
233
- w1 ,
234
- w2 ,
235
- topk_weight ,
236
- topk_ids ,
237
- global_num_experts = e
238
- )
239
-
240
- if False :
241
- torch .set_printoptions (profile = "full" )
242
- print ("BASELINE" )
243
- print (torch_output )
244
- print ("OUTPUT" )
245
- print (triton_output )
246
-
247
- torch .testing .assert_close (triton_output , torch_output , atol = 2e-2 , rtol = 0 )
248
-
249
-
250
113
@pytest .mark .parametrize ("m" , [1 , 32 , 222 ])
251
114
@pytest .mark .parametrize ("n" , [128 , 1024 , 2048 ])
252
115
@pytest .mark .parametrize ("k" , [128 , 1024 ])
@@ -587,7 +450,8 @@ def test_fused_marlin_moe(
587
450
topk_weights , topk_ids , token_expert_indices = fused_topk (
588
451
a , score , topk , False )
589
452
590
- torch_output = torch_moe (a , w_ref1 , w_ref2 , score , topk , e_map )
453
+ with set_current_vllm_config (vllm_config ):
454
+ torch_output = torch_moe (a , w_ref1 , w_ref2 , score , topk , e_map )
591
455
592
456
marlin_output = torch .ops .vllm .fused_marlin_moe (
593
457
a ,
0 commit comments