@@ -263,11 +263,13 @@ def trtllm_quant_nvfp4_moe_fused(
263263 act_fn: "silu" for gated_mlp, "relu2" for mlp
264264 """
265265 NVFP4_BLOCK_SIZE = TRTLLM_NVFP4_SCALING_VECTOR_SIZE
266- TRTLLM_NVFP4_NUM_ELEMENTS_PER_UINT8 = 2
266+ FP4_PER_UINT8 = 2
267267
268+ _ , fc1_inter_size , _ = fc1_expert_weights_fp4 .shape
268269 n_experts , hidden_size , inter_size = fc2_expert_weights_fp4 .shape
270+
269271 # Convert the inter_size from number of uint8 elements to number of FP4 elements.
270- inter_size *= TRTLLM_NVFP4_NUM_ELEMENTS_PER_UINT8
272+ inter_size *= FP4_PER_UINT8
271273
272274 # Validate shapes and padding requirements as defined by the cutlass kernel.
273275 assert fc1_weight_blockscale_fp8 .ndim == 3 , "fc1_weight_blockscale_fp8 must be 3D"
@@ -290,28 +292,42 @@ def trtllm_quant_nvfp4_moe_fused(
290292 input_blockscale = None
291293 output_dtype = x .dtype
292294
293- # Pad I to be divisible by 128
295+ # Pad inter_size to be divisible by 128
294296 inter_size_padded = math .ceil (inter_size / TRTLLM_NVFP4_ROW_SIZE ) * TRTLLM_NVFP4_ROW_SIZE
295- if not is_gated_mlp and inter_size_padded != inter_size :
296- # if False:
297- # fc1_expert_weights_fp4: [E, I, H]
297+ fc1_inter_size_padded = (
298+ math .ceil (fc1_inter_size / TRTLLM_NVFP4_ROW_SIZE ) * TRTLLM_NVFP4_ROW_SIZE
299+ )
300+ hidden_size_padded = (
301+ math .ceil (hidden_size / TRTLLM_NVFP4_COLUMN_SIZE ) * TRTLLM_NVFP4_COLUMN_SIZE
302+ )
303+
304+ inter_size_needs_padding = (is_gated_mlp and fc1_inter_size_padded != fc1_inter_size ) or (
305+ not is_gated_mlp and inter_size_padded != inter_size
306+ )
307+ hidden_size_needs_padding = hidden_size % TRTLLM_NVFP4_COLUMN_SIZE != 0
308+ if inter_size_needs_padding or hidden_size_needs_padding :
309+ # fc1_expert_weights_fp4: [E, I, H] or [E, 2*I, H]
298310 fc1_padded = fc1_expert_weights_fp4 .new_zeros (
299- n_experts , inter_size_padded , hidden_size // 2
311+ fc1_expert_weights_fp4 .size (0 ),
312+ fc1_inter_size_padded ,
313+ hidden_size_padded // FP4_PER_UINT8 ,
300314 )
301- fc1_padded [:, :inter_size , :] = fc1_expert_weights_fp4
315+ fc1_padded [:, :fc1_inter_size , :] = fc1_expert_weights_fp4
302316 fc1_expert_weights_fp4 = fc1_padded
303317
304318 # fc2_expert_weights_fp4: [E, H, I]
305319 fc2_padded = fc2_expert_weights_fp4 .new_zeros (
306- n_experts , hidden_size , inter_size_padded // 2
320+ n_experts , hidden_size_padded , inter_size_padded // FP4_PER_UINT8
307321 )
308- fc2_padded [:, :, : inter_size // 2 ] = fc2_expert_weights_fp4
322+ fc2_padded [:, :, : inter_size // FP4_PER_UINT8 ] = fc2_expert_weights_fp4
309323 fc2_expert_weights_fp4 = fc2_padded
310324
311325 fc2_blockscale_fp8_padded = fc2_weight_blockscale_fp8 .new_zeros (
312- n_experts , hidden_size , inter_size_padded // 16
326+ n_experts , hidden_size_padded , inter_size_padded // NVFP4_BLOCK_SIZE
327+ )
328+ fc2_blockscale_fp8_padded [:, :, : inter_size // NVFP4_BLOCK_SIZE ] = (
329+ fc2_weight_blockscale_fp8
313330 )
314- fc2_blockscale_fp8_padded [:, :, : inter_size // 16 ] = fc2_weight_blockscale_fp8
315331 fc2_weight_blockscale_fp8 = fc2_blockscale_fp8_padded
316332
317333 # quant_scales is described by this code:
0 commit comments