Commit ea250c8 (1 parent: 6139747)

Add moe module level changes

Signed-off-by: Dongfeng Yu <[email protected]>

2 files changed: 20 additions, 7 deletions

tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py (7 additions, 2 deletions)
@@ -179,7 +179,7 @@ def _check_configs(self):
             or self.has_w4a8_mxfp4_fp8 or self.has_w4a8_mxfp4_mxfp8, "TRTLLMGenFusedMoE only supports fp8_block_scaling, nvfp4, w4a16_mxfp4, w4a8_mxfp4_fp8 and w4a8_mxfp4_mxfp8 dtypes."
 
         if self.bias or self.swiglu_alpha is not None or self.swiglu_beta is not None or self.swiglu_limit is not None:
-            assert self.has_w4a16_mxfp4 or self.has_w4a8_mxfp4_fp8 or self.has_w4a8_mxfp4_mxfp8, "TRTLLMGenFusedMoE only supports mxfp4 quantization with bias, swiglu_alpha, swiglu_beta and swiglu_limit."
+            assert self.has_nvfp4 or self.has_w4a16_mxfp4 or self.has_w4a8_mxfp4_fp8 or self.has_w4a8_mxfp4_mxfp8, "TRTLLMGenFusedMoE supports bias/swiglu only for nvfp4 and mxfp4 variants."
 
     def _get_quant_method(self):
         if self.quant_config is not None:
@@ -213,7 +213,7 @@ def create_weights(self):
         self._weights_created = True
         self._check_configs()
 
-        if (self.has_w4a16_mxfp4 or self.has_w4a8_nvfp4_fp8
+        if (self.has_nvfp4 or self.has_w4a16_mxfp4 or self.has_w4a8_nvfp4_fp8
                 or self.has_w4a8_mxfp4_fp8
                 or self.has_w4a8_mxfp4_mxfp8) and not self.bias:
             self.w3_w1_bias = nn.Parameter(torch.zeros(
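
Note: with the widened condition above, a zero-filled w3_w1_bias is registered whenever one of these quantized modes is active but the layer has no user-provided bias, so the kernel call in forward_impl (next hunk) can pass a bias unconditionally. A minimal sketch of that fallback pattern; the helper name bias_or_zeros is hypothetical, not from this diff:

    from typing import Optional, Tuple
    import torch

    def bias_or_zeros(user_bias: Optional[torch.Tensor],
                      shape: Tuple[int, ...],
                      dtype: torch.dtype = torch.float32) -> torch.Tensor:
        # Hand downstream kernels the user's bias when present, else a
        # zero tensor, so the fused kernel signature can always take a
        # bias argument. Shape and dtype here are illustrative.
        if user_bias is not None:
            return user_bias
        return torch.zeros(shape, dtype=dtype)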
@@ -449,8 +449,13 @@ def forward_impl(
                 hidden_states_scale_linear_fp4.view(torch.float8_e4m3fn),
                 self.w3_w1_weight,
                 self.w3_w1_weight_scale.view(torch.float8_e4m3fn),
+                self.w3_w1_bias,
+                self.swiglu_alpha,
+                self.swiglu_beta,
+                self.swiglu_limit,
                 self.w2_weight,
                 self.w2_weight_scale.view(torch.float8_e4m3fn),
+                self.w2_bias,
                 self.fc31_scale_c.data,
                 self.fc31_alpha.data,
                 self.fc2_alpha.data,
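
The new arguments forwarded above (w3_w1_bias, swiglu_alpha, swiglu_beta, swiglu_limit, w2_bias) let the fused kernel apply per-expert biases and a clamped SwiGLU epilogue. The kernel's exact semantics are not visible in this diff; the sketch below assumes the common gpt-oss-style parameterization of the three scalars:

    from typing import Optional
    import torch

    def clamped_swiglu(x_glu: torch.Tensor,
                       x_linear: torch.Tensor,
                       alpha: float = 1.702,
                       beta: float = 1.0,
                       limit: Optional[float] = 7.0) -> torch.Tensor:
        # Clamp both branches before gating to bound the activation range.
        if limit is not None:
            x_glu = x_glu.clamp(max=limit)
            x_linear = x_linear.clamp(min=-limit, max=limit)
        # alpha scales the sigmoid gate; beta offsets the linear branch.
        return x_glu * torch.sigmoid(alpha * x_glu) * (x_linear + beta)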

tensorrt_llm/_torch/modules/fused_moe/quantization.py (13 additions, 5 deletions)
@@ -1526,7 +1526,8 @@ def create_weights(self,
                        weight_vec_size,
                        block_scales_dtype,
                        block_scales_vec_size,
-                       scaling_vector_size=16):
+                       scaling_vector_size=16,
+                       bias_dtype: Optional[torch.dtype] = None):
 
         module.scaling_vector_size = scaling_vector_size
         # Divide by 16 because we use int64 to pack 16 fp4 values
@@ -1576,8 +1577,11 @@ def create_weights(self,
                                  requires_grad=False)
         module.register_parameter("fc2_alpha", fc2_alpha)
 
-        super().create_weights(module, weight_dtype, w3_w1_weight_shape,
-                               w2_weight_shape)
+        super().create_weights(module,
+                               weight_dtype,
+                               w3_w1_weight_shape,
+                               w2_weight_shape,
+                               bias_dtype=bias_dtype)
 
         self.setup_quant_scales(module)
 
@@ -1856,8 +1860,12 @@ def create_weights(self, module: torch.nn.Module):
         weight_vec_size = torch.iinfo(self.weight_dtype).bits // 4
         block_scales_vec_size = 1
 
-        super().create_weights(module, self.weight_dtype, weight_vec_size,
-                               self.block_scales_dtype, block_scales_vec_size)
+        super().create_weights(module,
+                               self.weight_dtype,
+                               weight_vec_size,
+                               self.block_scales_dtype,
+                               block_scales_vec_size,
+                               bias_dtype=torch.float32)
 
         fc31_scale_c = nn.Parameter(torch.ones(module.expert_size_per_partition,
                                                dtype=torch.float32),
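
These hunks only thread bias_dtype down to the parent create_weights, pinning it to torch.float32 for this quantization path; how the base class consumes it is not shown in the diff. A plausible sketch under that assumption, with hypothetical helper and shape-parameter names:

    from typing import Optional, Tuple
    import torch
    import torch.nn as nn

    def maybe_register_biases(module: nn.Module,
                              w3_w1_bias_shape: Tuple[int, ...],
                              w2_bias_shape: Tuple[int, ...],
                              bias_dtype: Optional[torch.dtype] = None) -> None:
        # bias_dtype=None preserves the previous bias-free behavior; a
        # concrete dtype allocates frozen per-expert bias parameters that
        # forward_impl can pass straight to the fused kernel.
        if bias_dtype is None:
            return
        for name, shape in (("w3_w1_bias", w3_w1_bias_shape),
                            ("w2_bias", w2_bias_shape)):
            module.register_parameter(
                name,
                nn.Parameter(torch.zeros(shape, dtype=bias_dtype),
                             requires_grad=False))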
