@@ -144,12 +144,13 @@ def prepare(
                 "apply_router_weight_on_input is only implemented for topk=1")
             a1 = a1 * topk_weights.to(a1.dtype)
 
-        if quant_config.per_act_token_quant:
+        if quant_config.is_block_quantized:
+            # Quant and Dispatch
             a1q, a1q_scale = moe_kernel_quantize_input(
                 a1,
                 a1_scale,
                 quant_dtype=quant_config.quant_dtype,
-                per_act_token_quant=True,
+                per_act_token_quant=quant_config.per_act_token_quant,
                 block_shape=quant_config.block_shape,
             )
             if a1q_scale is not None and a1q_scale.numel() == 1:
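The hunk above widens the quantize-before-dispatch path: it is now taken for any block-quantized config rather than only for per-token quantization, and the hardcoded per_act_token_quant=True is replaced with the value carried by quant_config. A minimal sketch of the resulting branch selection, using a hypothetical QuantCfg stand-in for the real quant config object:

# Minimal sketch of the path selection this hunk encodes. QuantCfg is a
# hypothetical stand-in for quant_config, not vLLM's actual config class.
from dataclasses import dataclass
from typing import Optional

@dataclass
class QuantCfg:
    per_act_token_quant: bool = False
    block_shape: Optional[list[int]] = None  # e.g. [128, 128]

    @property
    def is_block_quantized(self) -> bool:
        # Block-quantized means a block shape was configured.
        return self.block_shape is not None

def quantizes_before_dispatch(cfg: QuantCfg) -> bool:
    # After this change: block-quantized activations are quantized before
    # dispatch; everything else is dispatched in bfloat16 and quantized after.
    return cfg.is_block_quantized

assert quantizes_before_dispatch(QuantCfg(block_shape=[128, 128]))
assert not quantizes_before_dispatch(QuantCfg(per_act_token_quant=True))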
@@ -162,16 +163,18 @@ def prepare(
                  rank_topk_weights=topk_weights,
                  num_experts=num_experts)
         else:
-            # DeepEP kernels only support dispatching per-token-quant
-            # quantization. dispatch in bfloat16.
+            # Dispatch and Quant
+            # DeepEP kernels only support dispatching block-quantized
+            # activation scales.
+            # Dispatch in bfloat16
             (expert_x, _, expert_tokens_meta, expert_topk_ids,
              expert_topk_weights) = self._do_dispatch(
                  tokens=a1,
                  token_scales=None,
                  rank_topk_ids=topk_ids,
                  rank_topk_weights=topk_weights,
                  num_experts=num_experts)
-            # quantize now
+            # Quantize after dispatch.
             expert_x_scale = None
             if expert_x.numel() != 0:
                 expert_x, expert_x_scale = moe_kernel_quantize_input(
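In the else branch, activations travel through the DeepEP dispatch in bfloat16 and each rank quantizes only the rows it actually received. A toy illustration of that post-dispatch, per-token quantization, assuming an fp8 e4m3 target (max representable value 448); per_token_quant here is an illustrative stand-in, not the real moe_kernel_quantize_input:

# Toy post-dispatch per-token quantization: one scale per received row.
# Assumes an fp8 e4m3 target; this is not vLLM's actual kernel code.
import torch

FP8_MAX = 448.0  # max representable value of float8_e4m3fn

def per_token_quant(x: torch.Tensor):
    # One scale per token (row), mirroring per-act-token quantization.
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / FP8_MAX
    xq = (x / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return xq, scale

expert_x = torch.randn(16, 128, dtype=torch.bfloat16)  # rows after dispatch
expert_x_scale = None
if expert_x.numel() != 0:  # empty ranks skip quantization, as in the hunk
    expert_x, expert_x_scale = per_token_quant(expert_x.float())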