diff --git a/cpp/tensorrt_llm/thop/fp4BlockScaleMoe.cpp b/cpp/tensorrt_llm/thop/fp4BlockScaleMoe.cpp index c9d9085614a..4d23df1b5be 100644 --- a/cpp/tensorrt_llm/thop/fp4BlockScaleMoe.cpp +++ b/cpp/tensorrt_llm/thop/fp4BlockScaleMoe.cpp @@ -34,15 +34,17 @@ using tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::computeSelectedTileN; std::vector run_fp4_block_scale_moe_runner(torch::optional const& routing_logits, torch::optional const& routing_bias, torch::Tensor const& hidden_states, torch::optional const& hidden_states_scale, torch::Tensor const& gemm1_weights, - torch::Tensor const& gemm1_weights_scale, torch::Tensor const& gemm2_weights, - torch::Tensor const& gemm2_weights_scale, torch::Tensor const& output1_scales_scalar, - torch::Tensor const& output1_scales_gate_scalar, torch::Tensor const& output2_scales_scalar, - int64_t const num_experts, int64_t const top_k, std::optional const n_group, - std::optional const topk_group, int64_t const intermediate_size, int64_t const local_expert_offset, - int64_t const local_num_experts, std::optional const routed_scaling_factor, int64_t const tile_tokens_dim, - int64_t const routing_method_type, bool const do_finalize, btg::Dtype const dtype, MoeRunnerType& moe_runner, - int64_t const moeConfigIndex, torch::optional const& topk_weights, - torch::optional const& topk_ids) + torch::Tensor const& gemm1_weights_scale, std::optional const& gemm1_bias, + std::optional const& gemm1_alpha, std::optional const& gemm1_beta, + std::optional const& gemm1_clamp_limit, torch::Tensor const& gemm2_weights, + torch::Tensor const& gemm2_weights_scale, std::optional const& gemm2_bias, + torch::Tensor const& output1_scales_scalar, torch::Tensor const& output1_scales_gate_scalar, + torch::Tensor const& output2_scales_scalar, int64_t const num_experts, int64_t const top_k, + std::optional const n_group, std::optional const topk_group, int64_t const intermediate_size, + int64_t const local_expert_offset, int64_t const local_num_experts, + std::optional const routed_scaling_factor, int64_t const tile_tokens_dim, int64_t const routing_method_type, + bool const do_finalize, btg::Dtype const dtype, MoeRunnerType& moe_runner, int64_t const moeConfigIndex, + torch::optional const& topk_weights, torch::optional const& topk_ids) { TORCH_CHECK(dtype == btg::Dtype::E4m3 || dtype == btg::Dtype::E2m1, "dtype can only be e4m3 or e2m1."); TORCH_CHECK(tensorrt_llm::common::isSM100Family(), "Only SM100f is supported by FP4 block scale MOE"); @@ -161,8 +163,13 @@ std::vector run_fp4_block_scale_moe_runner(torch::optional() : nullptr; + args.gemm1_alpha = gemm1_alpha.has_value() ? gemm1_alpha.value().data_ptr() : nullptr; + args.gemm1_beta = gemm1_beta.has_value() ? gemm1_beta.value().data_ptr() : nullptr; + args.gemm1_clamp_limit = gemm1_clamp_limit.has_value() ? gemm1_clamp_limit.value().data_ptr() : nullptr; args.gemm2_weights = gemm2_weights.data_ptr(); args.gemm2_weights_scale = gemm2_weights_scale.data_ptr(); + args.gemm2_bias = gemm2_bias.has_value() ? 
gemm2_bias.value().data_ptr() : nullptr; args.num_tokens = hidden_states.sizes()[0]; args.num_experts = num_experts; if (dtype == btg::Dtype::E4m3) @@ -313,6 +320,38 @@ std::vector run_fp4_block_scale_moe_runner(torch::optional run_fp4_block_scale_moe_runner(torch::optional run(torch::optional const& routing_logits, torch::optional const& routing_bias, torch::Tensor const& hidden_states, torch::Tensor const& hidden_states_scale, torch::Tensor const& gemm1_weights, - torch::Tensor const& gemm1_weights_scale, torch::Tensor const& gemm2_weights, - torch::Tensor const& gemm2_weights_scale, torch::Tensor const& output1_scales_scalar, - torch::Tensor const& output1_scales_gate_scalar, torch::Tensor const& output2_scales_scalar, - int64_t const num_experts, int64_t const top_k, std::optional const n_group, - std::optional const topk_group, int64_t const intermediate_size, int64_t const local_expert_offset, - int64_t const local_num_experts, std::optional const routed_scaling_factor, - int64_t const routing_method_type, bool const do_finalize, std::vector moeConfigIndex, - torch::optional const& topk_weights, torch::optional const& topk_ids) + torch::Tensor const& gemm1_weights_scale, std::optional const& gemm1_bias, + std::optional const& gemm1_alpha, std::optional const& gemm1_beta, + std::optional const& gemm1_clamp_limit, torch::Tensor const& gemm2_weights, + torch::Tensor const& gemm2_weights_scale, std::optional const& gemm2_bias, + torch::Tensor const& output1_scales_scalar, torch::Tensor const& output1_scales_gate_scalar, + torch::Tensor const& output2_scales_scalar, int64_t const num_experts, int64_t const top_k, + std::optional const n_group, std::optional const topk_group, int64_t const intermediate_size, + int64_t const local_expert_offset, int64_t const local_num_experts, + std::optional const routed_scaling_factor, int64_t const routing_method_type, bool const do_finalize, + std::vector moeConfigIndex, torch::optional const& topk_weights, + torch::optional const& topk_ids) { // moeConfigIndex corresponds to pair (tileN, config) auto [tileN, config] = std::tie(moeConfigIndex[0], moeConfigIndex[1]); @@ -468,10 +519,11 @@ class FP4BlockScaleMoeRunner : public torch::CustomClassHolder } return run_fp4_block_scale_moe_runner(routing_logits, routing_bias, hidden_states, hidden_states_scale, - gemm1_weights, gemm1_weights_scale, gemm2_weights, gemm2_weights_scale, output1_scales_scalar, - output1_scales_gate_scalar, output2_scales_scalar, num_experts, top_k, n_group, topk_group, - intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor, tileN, - routing_method_type, do_finalize, mDtypeElt, *mRunners[tileN], config, topk_weights, topk_ids); + gemm1_weights, gemm1_weights_scale, gemm1_bias, gemm1_alpha, gemm1_beta, gemm1_clamp_limit, gemm2_weights, + gemm2_weights_scale, gemm2_bias, output1_scales_scalar, output1_scales_gate_scalar, output2_scales_scalar, + num_experts, top_k, n_group, topk_group, intermediate_size, local_expert_offset, local_num_experts, + routed_scaling_factor, tileN, routing_method_type, do_finalize, mDtypeElt, *mRunners[tileN], config, + topk_weights, topk_ids); } private: @@ -553,11 +605,11 @@ class FP8FP4BlockScaleMoeRunner : public torch::CustomClassHolder } return run_fp4_block_scale_moe_runner(routing_logits, routing_bias, hidden_states, - std::nullopt /*hidden_states_scale*/, gemm1_weights, gemm1_weights_scale, gemm2_weights, - gemm2_weights_scale, output1_scales_scalar, output1_scales_gate_scalar, output2_scales_scalar, num_experts, - 
top_k, n_group, topk_group, intermediate_size, local_expert_offset, local_num_experts, - routed_scaling_factor, tileN, routing_method_type, do_finalize, mDtypeAct, *mRunners[tileN], config, - topk_weights, topk_ids); + std::nullopt /*hidden_states_scale*/, gemm1_weights, gemm1_weights_scale, std::nullopt, std::nullopt, + std::nullopt, std::nullopt, gemm2_weights, gemm2_weights_scale, std::nullopt, output1_scales_scalar, + output1_scales_gate_scalar, output2_scales_scalar, num_experts, top_k, n_group, topk_group, + intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor, tileN, + routing_method_type, do_finalize, mDtypeAct, *mRunners[tileN], config, topk_weights, topk_ids); } private: diff --git a/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py b/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py index a8236d88fcf..31107f752cd 100644 --- a/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py @@ -176,8 +176,13 @@ class FP4BlockScaleMoEInputs: hidden_states_scale: torch.Tensor gemm1_weights: torch.Tensor gemm1_weights_scale: torch.Tensor + gemm1_bias: torch.Tensor + gemm1_alpha: torch.Tensor + gemm1_beta: torch.Tensor + gemm1_clamp_limit: torch.Tensor gemm2_weights: torch.Tensor gemm2_weights_scale: torch.Tensor + gemm2_bias: torch.Tensor output1_scale_scalar: torch.Tensor output1_scale_gate_scalar: torch.Tensor output2_scale_scalar: torch.Tensor @@ -235,14 +240,15 @@ def forward( return kernel_runner.run_moe( args.routing_logits, args.routing_bias, args.hidden_states, args.hidden_states_scale, args.gemm1_weights, - args.gemm1_weights_scale, args.gemm2_weights, - args.gemm2_weights_scale, args.output1_scale_scalar, - args.output1_scale_gate_scalar, args.output2_scale_scalar, - self.num_experts, self.top_k, self.n_group, self.topk_group, - self.intermediate_size, self.local_expert_offset, - self.local_num_experts, self.routed_scaling_factor, - self.routing_method_type, self.do_finalize, tactic, - args.topk_weights, args.topk_ids) + args.gemm1_weights_scale, args.gemm1_bias, args.gemm1_alpha, + args.gemm1_beta, args.gemm1_clamp_limit, args.gemm2_weights, + args.gemm2_weights_scale, args.gemm2_bias, + args.output1_scale_scalar, args.output1_scale_gate_scalar, + args.output2_scale_scalar, self.num_experts, self.top_k, + self.n_group, self.topk_group, self.intermediate_size, + self.local_expert_offset, self.local_num_experts, + self.routed_scaling_factor, self.routing_method_type, + self.do_finalize, tactic, args.topk_weights, args.topk_ids) def get_valid_tactics(self, inputs: List[torch.Tensor], profile: OptimizationProfile, @@ -359,8 +365,13 @@ def fp4_block_scale_moe_runner( hidden_states_scale: torch.Tensor, gemm1_weights: torch.Tensor, gemm1_weights_scale: torch.Tensor, + gemm1_bias: torch.Tensor, + gemm1_alpha: torch.Tensor, + gemm1_beta: torch.Tensor, + gemm1_clamp_limit: torch.Tensor, gemm2_weights: torch.Tensor, gemm2_weights_scale: torch.Tensor, + gemm2_bias: torch.Tensor, output1_scale_scalar: torch.Tensor, output1_scale_gate_scalar: torch.Tensor, output2_scale_scalar: torch.Tensor, @@ -416,8 +427,13 @@ def fp4_block_scale_moe_runner( hidden_states_scale, gemm1_weights, gemm1_weights_scale, + gemm1_bias, + gemm1_alpha, + gemm1_beta, + gemm1_clamp_limit, gemm2_weights, gemm2_weights_scale, + gemm2_bias, output1_scale_scalar, output1_scale_gate_scalar, output2_scale_scalar, @@ -474,8 +490,13 @@ def _(routing_logits, hidden_states_scale, gemm1_weights, gemm1_weights_scale, + gemm1_bias, + 
gemm1_alpha, + gemm1_beta, + gemm1_clamp_limit, gemm2_weights, gemm2_weights_scale, + gemm2_bias, output1_scale_scalar, output1_scale_gate_scalar, output2_scale_scalar, diff --git a/tensorrt_llm/_torch/models/modeling_gpt_oss.py b/tensorrt_llm/_torch/models/modeling_gpt_oss.py index aeec74e5e65..5e0dc9a486c 100644 --- a/tensorrt_llm/_torch/models/modeling_gpt_oss.py +++ b/tensorrt_llm/_torch/models/modeling_gpt_oss.py @@ -34,8 +34,7 @@ from ..speculative import SpecMetadata from ..utils import Fp4QuantizedTensor from .modeling_speculative import SpecDecOneEngineForCausalLM -from .modeling_utils import (DecoderModel, duplicate_kv_weight, filter_weights, - register_auto_model) +from .modeling_utils import DecoderModel, filter_weights, register_auto_model # Use TinyGEMM when the number of tokens is not larger than this threshold MIN_LATENCY_TINYGEMM_NUM_TOKENS = 128 @@ -639,6 +638,15 @@ def __post_init__(self): quant_config = self.model_config.quant_config if quant_config.exclude_modules: + if quant_config.quant_algo == "NVFP4": + quant_config.exclude_modules = [ + 'block.*.attn.qkv', + 'block.*.attn.out', + 'block.*.mlp.gate', + 'embedding', + 'unembedding', + ] + for i, module in enumerate(quant_config.exclude_modules): names = module.split(".") if names[-1] in params_map_reverse: @@ -653,13 +661,10 @@ def __post_init__(self): module.create_weights() def load_weights(self, weights: Dict): - is_ori_model = True - for k, v in weights.items(): - if 'q_proj' in k: - is_ori_model = False + is_nvfp4 = self.model_config.quant_config.quant_mode.has_nvfp4() - if is_ori_model: - self.load_ori_weights(weights) + if is_nvfp4: + self.load_nvfp4_weights(weights) else: self.load_hf_weights(weights) @@ -811,176 +816,107 @@ def load_hf_weights(self, weights: Dict): if p is not None: p.data.copy_(module_weights[n][:]) - def load_ori_weights(self, weights: Dict): - head_dim = self.config.head_dim - num_q_head = self.config.num_attention_heads - num_kv_head = self.config.num_key_value_heads + def load_nvfp4_weights(self, weights: Dict): num_expert = self.config.num_local_experts - enable_attention_dp = self.model_config.mapping.enable_attention_dp - tp_size = self.model_config.mapping.tp_size for name, module in tqdm(list(self.named_modules()), desc="Loading weights"): if len(module._parameters) <= 0 or name.startswith("draft_model"): continue - names = name.split(".") + module_weights = {} - if names[-1] in self.params_map: - names[-1] = self.params_map[names[-1]] + for k, v in self.hf_params_map.items(): + name = name.replace(k, v) + + names = name.split('.') + if names[-1] == "backend" and isinstance(module, MoE): + # Backend is under experts module (ConfigurableMoE wrapper) + name = '.'.join(names[:-1]) - # Drop the first "model" prefix - if names[0] == 'model': - name = '.'.join(names[1:]) - else: - name = '.'.join(names) module_weights = filter_weights(name, weights) + if isinstance(module, MoE): - # [num_experts, intermediate_size * 2, hidden_size] - gate_up_proj = filter_weights(name.replace("experts", "mlp1"), - weights) - # [num_experts, intermediate_size, hidden_size] - down_proj = filter_weights(name.replace("experts", "mlp2"), - weights) - try: - # Official MXFP4 ckpt. 
- gate_up_weight = gate_up_proj['weight.blocks'].flatten( - -2, -1) - gate, up = gate_up_weight[:, ::2, :], gate_up_weight[:, 1:: - 2, :] - gate_up_weight = torch.cat([gate, up], dim=-2) - gate_up_bias = gate_up_proj['bias'] - gate, up = gate_up_bias[:, ::2], gate_up_bias[:, 1::2] - gate_up_bias = torch.cat([gate, up], dim=-1) - moe_weights = { - 'gate_up_proj': [ - gate_up_weight[i, :, :].transpose(0, 1) - for i in range(num_expert) - ], - 'down_proj': [ - down_proj['weight.blocks'].flatten( - -2, -1)[i, :, :].transpose(0, 1) - for i in range(num_expert) - ], - 'gate_up_proj.bias': - [gate_up_bias[i, :] for i in range(num_expert)], - 'down_proj.bias': - [down_proj['bias'][i, :] for i in range(num_expert)] - } - except: - # For BF16 ckpt. - moe_weights = { - 'gate_up_proj': [ - gate_up_proj['weight'][i, :, :].transpose(0, 1).to( - self.model.dtype) for i in range(num_expert) - ], - 'down_proj': [ - down_proj['weight'][i, :, :].transpose(0, 1).to( - self.model.dtype) for i in range(num_expert) - ], - 'gate_up_proj.bias': - [gate_up_proj['bias'][i, :] for i in range(num_expert)], - 'down_proj.bias': - [down_proj['bias'][i, :] for i in range(num_expert)] - } - # Only for Official MXFP4 ckpt. - if 'weight.scales' in gate_up_proj: - gate_up_weight_scale = gate_up_proj['weight.scales'] - gate, up = gate_up_weight_scale[:, :: - 2, :], gate_up_weight_scale[:, - 1:: - 2, :] - gate_up_weight_scale = torch.cat([gate, up], dim=-2) - moe_weights['gate_up_proj_weight_scale'] = [ - gate_up_weight_scale[i, :, :].transpose(0, 1) - for i in range(num_expert) + assert getattr(module, "quant_config", None) is not None and \ + module.quant_config.quant_mode.has_nvfp4() + gate_up = module_weights.get('gate_up_proj', None) + down = module_weights.get('down_proj', None) + gate_up_bias = module_weights.get('gate_up_proj_bias', None) + down_bias = module_weights.get('down_proj_bias', None) + + def deinterleave(tensor): + g, u = tensor[..., ::2], tensor[..., 1::2] + return torch.cat([g, u], dim=-1) + + gate_up = deinterleave(gate_up) + gate_up_bias = deinterleave(gate_up_bias) + + # Only fp32 bias is supported for NVFP4 MoE. 
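# Illustrative sketch (not from the diff) of the gate/up de-interleaving done by
# the deinterleave helper above: GPT-OSS checkpoints store the fused gate_up
# projection with gate and up columns interleaved along the last dimension
# ([g0, u0, g1, u1, ...]), while the fused-MoE loader expects all gate columns
# followed by all up columns. Shapes below are made up for illustration.
import torch

def deinterleave(tensor: torch.Tensor) -> torch.Tensor:
    # Split even (gate) and odd (up) positions of the last dim,
    # then concatenate them back as [gate | up].
    g, u = tensor[..., ::2], tensor[..., 1::2]
    return torch.cat([g, u], dim=-1)

num_experts, rows, two_inter = 2, 4, 8
gate_up = torch.arange(num_experts * rows * two_inter, dtype=torch.float32)
gate_up = gate_up.reshape(num_experts, rows, two_inter)

out = deinterleave(gate_up)
# The first half of the last dim now holds the former even (gate) columns
# and the second half the former odd (up) columns.
assert torch.equal(out[..., : two_inter // 2], gate_up[..., ::2])
assert torch.equal(out[..., two_inter // 2 :], gate_up[..., 1::2])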
+ if gate_up_bias.dtype != torch.float32: + gate_up_bias = gate_up_bias.to(torch.float32) + if down_bias.dtype != torch.float32: + down_bias = down_bias.to(torch.float32) + + moe_weights = {} + if gate_up is not None: + moe_weights['gate_up_proj'] = [ + gate_up[i, :, :] for i in range(num_expert) + ] + if down is not None: + moe_weights['down_proj'] = [ + down[i, :, :] for i in range(num_expert) + ] + if gate_up_bias is not None: + moe_weights['gate_up_proj.bias'] = [ + gate_up_bias[i, :] for i in range(num_expert) + ] + if down_bias is not None: + moe_weights['down_proj.bias'] = [ + down_bias[i, :] for i in range(num_expert) ] - if self.model_config.quant_config.quant_algo == 'W4A16_MXFP4': - for i in range(num_expert): - moe_weights[f"{i}.w1.weight_scale_inv"] = gate[ - i, :, :] - moe_weights[f"{i}.w3.weight_scale_inv"] = up[ - i, :, :] - - if 'weight.scales' in down_proj: + # Per-expert block scales (transpose to expected layout) + if 'gate_up_proj_weight_scale' in module_weights: + gu_ws = module_weights['gate_up_proj_weight_scale'] + gu_ws = deinterleave(gu_ws) + moe_weights['gate_up_proj_weight_scale'] = [ + gu_ws[i, :, :] for i in range(num_expert) + ] + if 'down_proj_weight_scale' in module_weights: + dp_ws = module_weights['down_proj_weight_scale'] moe_weights['down_proj_weight_scale'] = [ - down_proj['weight.scales'][i, :, :].transpose(0, 1) - for i in range(num_expert) + dp_ws[i, :, :] for i in range(num_expert) ] - if self.model_config.quant_config.quant_algo == 'W4A16_MXFP4': - for i in range(num_expert): - moe_weights[f"{i}.w2.weight_scale_inv"] = down_proj[ - 'weight.scales'][i, :, :] + # Module-level globals for NVFP4 loaders + for src_key in [ + 'gate_up_proj_weight_scale_2', + 'down_proj_weight_scale_2', + 'gate_up_proj_input_scale', + 'down_proj_input_scale', + ]: + if src_key in module_weights: + moe_weights[src_key] = module_weights[src_key] module.load_weights(weights=[moe_weights]) elif hasattr(module, "load_weights"): - # Load Attention module weights. 
if 'qkv' in name: - q_weight = module_weights['weight'][:head_dim * - num_q_head, :] - k_weight = module_weights['weight'][head_dim * - num_q_head:head_dim * - (num_q_head + - num_kv_head), :] - v_weight = module_weights['weight'][-head_dim * - num_kv_head:, :] - q_bias = module_weights['bias'][:head_dim * num_q_head] - k_bias = module_weights['bias'][head_dim * - num_q_head:head_dim * - (num_q_head + num_kv_head)] - v_bias = module_weights['bias'][-head_dim * num_kv_head:] - - # Handle KV weight duplication for GQA - tensors_need_duplication = ['weight', 'bias'] - if module.quant_config.quant_mode.has_mxfp4(): - tensors_need_duplication.append('weight_scale') - - # Duplicate KV weights if needed - tensor_parallel_size = tp_size if not enable_attention_dp else 1 - - k_weight_dict = {'weight': k_weight, 'bias': k_bias} - v_weight_dict = {'weight': v_weight, 'bias': v_bias} - - if 'weight_scale' in module_weights: - k_weight_dict['weight_scale'] = module_weights[ - 'weight_scale'][head_dim * num_q_head:head_dim * - (num_q_head + num_kv_head), :] - v_weight_dict['weight_scale'] = module_weights[ - 'weight_scale'][-head_dim * num_kv_head:, :] - - k_weight_dict = { - k: (duplicate_kv_weight( - weight=v, - num_kv_heads=num_kv_head, - tensor_parallel_size=tensor_parallel_size) - if k in tensors_need_duplication else v) - for k, v in k_weight_dict.items() - } - - v_weight_dict = { - k: (duplicate_kv_weight( - weight=v, - num_kv_heads=num_kv_head, - tensor_parallel_size=tensor_parallel_size) - if k in tensors_need_duplication else v) - for k, v in v_weight_dict.items() - } - - qkv_weights = [{ - 'weight': q_weight, - 'bias': q_bias - }, k_weight_dict, v_weight_dict] - module.load_weights(weights=qkv_weights) + # For qkv_proj + q_weight_bias = filter_weights( + name.replace('qkv_proj', 'q_proj'), weights) + k_weight_bias = filter_weights( + name.replace('qkv_proj', 'k_proj'), weights) + v_weight_bias = filter_weights( + name.replace('qkv_proj', 'v_proj'), weights) + module.load_weights( + weights=[q_weight_bias, k_weight_bias, v_weight_bias]) else: - # Dense & gate & sinks + # For o_proj, sinks. module.load_weights(weights=[module_weights]) else: - # Load LN weights. - if names[-1].endswith("layernorm") and names[-3] == "block": - # skip loading weights for the fused norms + # Load four LN weights (attn.norm, mlp.norm, input_layernorm, post_attention_layernorm). + if 'next_layer_layernorm' in name: continue + for n, p in module._parameters.items(): if p is not None: - p.data.copy_(module_weights[n.replace( - "weight", "scale")][:]) + p.data.copy_(module_weights[n][:]) diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py index 5eadece6a99..1e8b310afe3 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py @@ -214,7 +214,7 @@ def _check_configs(self): or self.has_w4a8_mxfp4_fp8 or self.has_w4a8_mxfp4_mxfp8, "TRTLLMGenFusedMoE only supports fp8_block_scaling, nvfp4, w4a16_mxfp4, w4a8_mxfp4_fp8 and w4a8_mxfp4_mxfp8 dtypes." if self.bias or self.swiglu_alpha is not None or self.swiglu_beta is not None or self.swiglu_limit is not None: - assert self.has_w4a16_mxfp4 or self.has_w4a8_mxfp4_fp8 or self.has_w4a8_mxfp4_mxfp8, "TRTLLMGenFusedMoE only supports mxfp4 quantization with bias, swiglu_alpha, swiglu_beta and swiglu_limit." 
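# Minimal sketch of the pad-then-slice pattern used by the quantize_input and
# run_moe hunks that follow in this file's diff: when the NVFP4 weights are
# padded to the kernel alignment, the activation is right-padded with zeros to
# the padded hidden size before quantization, and the MoE output is sliced back
# to the original hidden size afterwards. Sizes here are only illustrative
# (2880 padded to 3072 under a 256 alignment).
import torch
import torch.nn.functional as F

hidden_size, padded_hidden_size = 2880, 3072
x = torch.randn(4, hidden_size)

pad_size = padded_hidden_size - x.shape[-1]
if pad_size > 0:
    # Zero padding on the right of the last dimension.
    x = F.pad(x, (0, pad_size))
assert x.shape[-1] == padded_hidden_size

# ... the fused MoE runs on the padded tensors and returns a padded output ...
out = torch.randn(4, padded_hidden_size)
if out.shape[1] > hidden_size:
    out = out[:, :hidden_size].contiguous()
assert out.shape[1] == hidden_size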
+ assert self.has_nvfp4 or self.has_w4a16_mxfp4 or self.has_w4a8_mxfp4_fp8 or self.has_w4a8_mxfp4_mxfp8, "TRTLLMGenFusedMoE supports bias/swiglu only for nvfp4 and mxfp4 variants." def _get_quant_method(self): if self.quant_config is not None and self.quant_config.layer_quant_mode.has_any_quant( @@ -249,7 +249,7 @@ def create_weights(self): self._weights_created = True self._check_configs() - if (self.has_w4a16_mxfp4 or self.has_w4a8_nvfp4_fp8 + if (self.has_nvfp4 or self.has_w4a16_mxfp4 or self.has_w4a8_nvfp4_fp8 or self.has_w4a8_mxfp4_fp8 or self.has_w4a8_mxfp4_mxfp8) and not self.bias: self.w3_w1_bias = nn.Parameter(torch.zeros( @@ -324,6 +324,11 @@ def quantize_input(self, x, post_quant_comm: bool = True): self, 'fc31_act_scale') and self.fc31_act_scale is not None: x = x * self.fc31_act_scale + + pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1] + if pad_size > 0: + x = torch.nn.functional.pad(x, (0, pad_size)) + x_row = x.shape[0] x, x_sf = torch.ops.trtllm.fp4_quantize( x, self.fc31_input_scale, self.scaling_vector_size, False, @@ -446,6 +451,8 @@ def run_moe( topk_ids=token_selected_experts, ) elif self.has_nvfp4: + intermediate_size_per_partition_padded = self.w3_w1_weight.shape[ + -2] // 2 outputs = torch.ops.trtllm.fp4_block_scale_moe_runner( router_logits, @@ -454,8 +461,13 @@ def run_moe( x_sf.view(torch.float8_e4m3fn), self.w3_w1_weight, self.w3_w1_weight_scale.view(torch.float8_e4m3fn), + self.w3_w1_bias if self.bias else None, + self.swiglu_alpha, + self.swiglu_beta, + self.swiglu_limit, self.w2_weight, self.w2_weight_scale.view(torch.float8_e4m3fn), + self.w2_bias if self.bias else None, self.fc31_scale_c.data, self.fc31_alpha.data, self.fc2_alpha.data, @@ -463,7 +475,7 @@ def run_moe( top_k, n_group, topk_group, - self.intermediate_size_per_partition, + intermediate_size_per_partition_padded, self.slot_start, self.expert_size_per_partition, routed_scaling_factor, @@ -478,6 +490,11 @@ def run_moe( return outputs else: final_hidden_states = outputs[0] + # Slice output if it was padded + if final_hidden_states.shape[1] > self.hidden_size: + final_hidden_states = final_hidden_states[:, :self. 
+ hidden_size].contiguous( + ) elif self.has_w4a16_mxfp4: assert x.dtype == torch.bfloat16 diff --git a/tensorrt_llm/_torch/modules/fused_moe/quantization.py b/tensorrt_llm/_torch/modules/fused_moe/quantization.py index c688263ba6a..f5c12f68ed1 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/quantization.py +++ b/tensorrt_llm/_torch/modules/fused_moe/quantization.py @@ -218,13 +218,11 @@ def create_weights( # bias if module.bias: + # The shape might be padded so we use weight shape[:2] if w3_w1_bias_shape is None: - w3_w1_bias_shape = ( - module.expert_size_per_partition, - module.expand_intermediate_size_per_partition) + w3_w1_bias_shape = w3_w1_weight_shape[:2] if w2_bias_shape is None: - w2_bias_shape = (module.expert_size_per_partition, - module.hidden_size) + w2_bias_shape = w2_weight_shape[:2] bias_dtype = bias_dtype or module.dtype w3_w1_bias = nn.Parameter(torch.empty(w3_w1_bias_shape, dtype=bias_dtype), @@ -1731,7 +1729,8 @@ def create_weights(self, weight_vec_size, block_scales_dtype, block_scales_vec_size, - scaling_vector_size=16): + scaling_vector_size=16, + bias_dtype: Optional[torch.dtype] = None): module.scaling_vector_size = scaling_vector_size @@ -1780,7 +1779,8 @@ def create_weights(self, w3_w1_weight_shape=w3_w1_weight_shape, w2_weight_shape=w2_weight_shape, w3_w1_bias_shape=w3_w1_bias_shape, - w2_bias_shape=w2_bias_shape) + w2_bias_shape=w2_bias_shape, + bias_dtype=bias_dtype) self.setup_quant_scales(module) @@ -2300,7 +2300,7 @@ def load_expert_w3_w1_weight_scale_nvfp4( dst_w3_w1_weight_scale.copy_(w3_w1_weight_scale_interleaved) -class NVFP4TRTLLMGenFusedMoEMethod(NVFP4FusedMoEMethod): +class NVFP4TRTLLMGenFusedMoEMethodBase(NVFP4FusedMoEMethod): weight_dtype = float4_sf_dtype block_scales_dtype = torch.float8_e4m3fn @@ -2308,12 +2308,18 @@ class NVFP4TRTLLMGenFusedMoEMethod(NVFP4FusedMoEMethod): # This assumes the same input shape always results in the same permute indices _cache_permute_indices: Dict[torch.Size, torch.Tensor] = {} - def create_weights(self, module: torch.nn.Module): + def create_weights(self, + module: torch.nn.Module, + bias_dtype: Optional[torch.dtype] = None): weight_vec_size = torch.iinfo(self.weight_dtype).bits // 4 block_scales_vec_size = 1 - super().create_weights(module, self.weight_dtype, weight_vec_size, - self.block_scales_dtype, block_scales_vec_size) + super().create_weights(module, + self.weight_dtype, + weight_vec_size, + self.block_scales_dtype, + block_scales_vec_size, + bias_dtype=bias_dtype) fc31_scale_c = nn.Parameter(torch.ones(module.expert_size_per_partition, dtype=torch.float32), @@ -2565,7 +2571,340 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict): }) -class W4A8NVFP4FP8TRTLLMGenFusedMoEMethod(NVFP4TRTLLMGenFusedMoEMethod): +class NVFP4TRTLLMGenFusedMoEMethod(NVFP4TRTLLMGenFusedMoEMethodBase): + weight_alignment = 32 + input_hidden_alignment = 32 + + def get_weights_shapes(self, module: torch.nn.Module, weight_vec_size: int, + block_scales_vec_size: int): + + def round_up(x, alignment): + return (x + alignment - 1) // alignment * alignment + + # Compute padded sizes + intermediate_size_per_partition_padded = round_up( + module.intermediate_size_per_partition, self.weight_alignment) + w3_w1_hidden_size_padded = round_up(module.hidden_size, + self.input_hidden_alignment) + w2_hidden_size_padded = round_up(module.hidden_size, + self.weight_alignment) + + # Divide by 16 because we use int64 to pack 16 fp4 values + w3_w1_weight_shape = (module.expert_size_per_partition, + 
intermediate_size_per_partition_padded * + module.intermediate_size_expand_ratio, + w3_w1_hidden_size_padded // weight_vec_size) + w2_weight_shape = (module.expert_size_per_partition, + w2_hidden_size_padded, + intermediate_size_per_partition_padded // + weight_vec_size) + + w3_w1_weight_scale_shape = (module.expert_size_per_partition, + intermediate_size_per_partition_padded * + module.intermediate_size_expand_ratio, + w3_w1_hidden_size_padded // + module.scaling_vector_size // + block_scales_vec_size) + w2_weight_scale_shape = (module.expert_size_per_partition, + w2_hidden_size_padded, + intermediate_size_per_partition_padded // + module.scaling_vector_size // + block_scales_vec_size) + + if module.bias: + w3_w1_bias_shape = (module.expert_size_per_partition, + intermediate_size_per_partition_padded * + module.intermediate_size_expand_ratio) + w2_bias_shape = (module.expert_size_per_partition, + w2_hidden_size_padded) + else: + w3_w1_bias_shape = None + w2_bias_shape = None + + return (w3_w1_weight_shape, w2_weight_shape, w3_w1_bias_shape, + w2_bias_shape, w3_w1_weight_scale_shape, w2_weight_scale_shape) + + def create_weights(self, module: torch.nn.Module): + # Here we only enable padding for hidden_size > 1024 since there are small unit tests that expect no padding. + if module.hidden_size > 1024 and module.hidden_size % 256 != 0: + self.weight_alignment = 256 + # For now let's keep input alignment same as weight alignment. There are practical reasons that this might be a different value. + # See the comment in MXFP4WeightTRTLLMGenFusedMoEMethod for more details. + self.input_hidden_alignment = 256 + + super().create_weights(module, bias_dtype=torch.float32) + + def setup_quant_scales(self, module: torch.nn.Module): + module.quant_scales = tuple() + + def load_expert_w3_w1_weight(self, module: torch.nn.Module, + w1_weight: torch.Tensor, + w3_weight: torch.Tensor, + dst_w3_w1_weight: torch.Tensor): + device = torch.device(f"cuda:{torch.cuda.current_device()}") + dst_on_gpu = dst_w3_w1_weight.device.type == "cuda" + dst_w3_w1_weight_gpu = dst_w3_w1_weight if dst_on_gpu else dst_w3_w1_weight.cuda( + ) + + alignment = _get_weight_alignment(self.weight_alignment, + module.scaling_vector_size, + module.tp_size, w1_weight.shape[0]) + if len(w1_weight.shape) == 2: + # Pad weights + # We already satisfy alignment factor of 2 for we pack two MXFP4 into Uint8. + assert w1_weight.dtype == torch.uint8 + w1_weight = maybe_pad_for_mxfp4(w1_weight, + self.input_hidden_alignment // 2, + alignment) + assert w3_weight.dtype == torch.uint8 + w3_weight = maybe_pad_for_mxfp4(w3_weight, + self.input_hidden_alignment // 2, + alignment) + else: + # Pad bias, TRTLLM backend expects float32 bias. 
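# Sketch of the round_up padding used by get_weights_shapes above, with the
# class defaults introduced in this diff (weight_alignment = 32, bumped to 256
# for hidden sizes > 1024 that are not already a multiple of 256). The concrete
# module sizes are only an example.
def round_up(x: int, alignment: int) -> int:
    return (x + alignment - 1) // alignment * alignment

hidden_size = 2880                      # e.g. GPT-OSS 20B
intermediate_size_per_partition = 2880

weight_alignment = 32
if hidden_size > 1024 and hidden_size % 256 != 0:
    weight_alignment = 256

padded_intermediate = round_up(intermediate_size_per_partition, weight_alignment)
padded_hidden = round_up(hidden_size, weight_alignment)
print(padded_intermediate, padded_hidden)   # 3072 3072 with the values above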
+ assert len(w1_weight.shape) == 1 + assert len(w3_weight.shape) == 1 + w1_weight = maybe_pad_for_mxfp4(w1_weight, alignment).float() + w3_weight = maybe_pad_for_mxfp4(w3_weight, alignment).float() + + w1_weight_shard = load_weight_shard(w1_weight, + module.tp_size, + module.tp_rank, + TensorParallelMode.COLUMN, + device=device) + w3_weight_shard = load_weight_shard(w3_weight, + module.tp_size, + module.tp_rank, + TensorParallelMode.COLUMN, + device=device) + + # FIXME: this depends on the kernel internals + epilogue_tile_m = 128 + + # Keep weights in device buffer + dst_w3_weight, dst_w1_weight = dst_w3_w1_weight_gpu.chunk(2, dim=0) + dst_w3_weight.copy_(w3_weight_shard.view(dst_w3_weight.dtype)) + dst_w1_weight.copy_(w1_weight_shard.view(dst_w1_weight.dtype)) + + # Get permute indices + permute_indices = trtllmgen_maybe_get_cached_w3_w1_permute_indices( + dst_w3_w1_weight_gpu, self._cache_permute_indices, epilogue_tile_m) + + # Shuffle the weight according to permute indices + processed_w31_weight_shard = torch.ops.trtllm.shuffle_matrix( + dst_w3_w1_weight_gpu, + permute_indices.to(dst_w3_w1_weight_gpu.device)) + + # Copy the result into device buffer + dst_w3_w1_weight_gpu.copy_(processed_w31_weight_shard.view( + dst_w3_w1_weight_gpu.dtype), + non_blocking=dst_on_gpu) + if not dst_on_gpu: + dst_w3_w1_weight.copy_(dst_w3_w1_weight_gpu) + + def load_expert_w2_weight(self, module: torch.nn.Module, + w2_weight: torch.Tensor, + dst_w2_weight: torch.Tensor): + device = torch.device(f"cuda:{torch.cuda.current_device()}") + dst_on_gpu = dst_w2_weight.device.type == "cuda" + dst_w2_weight_gpu = dst_w2_weight if dst_on_gpu else dst_w2_weight.cuda( + ) + + shard_w2_weight_dim = 2 * w2_weight.shape[1] if len( + w2_weight.shape) == 2 else w2_weight.shape[0] + alignment = _get_weight_alignment(self.weight_alignment, + module.scaling_vector_size, + module.tp_size, shard_w2_weight_dim) + if len(w2_weight.shape) == 2: + assert w2_weight.dtype == torch.uint8 + w2_weight = maybe_pad_for_mxfp4(w2_weight, alignment // 2, + self.weight_alignment) + else: + assert len(w2_weight.shape) == 1 + w2_weight = maybe_pad_for_mxfp4(w2_weight, self.weight_alignment) + + # Divide bias by tp_size as we shard along the hidden dimension. + # The bias is applied at each TP rank before the final accumulation. 
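# Why the bias is divided by tp_size in the line that follows: w2 is
# row-parallel (sharded along the intermediate dimension), so every TP rank
# produces a partial output that is summed by the final all-reduce. If each
# rank adds bias / tp_size, the reduced result contains the bias exactly once.
# Toy demonstration with tp_size = 2 and no real tensor parallelism.
import torch

torch.manual_seed(0)
tp_size = 2
x = torch.randn(3, 8)            # activations, intermediate dim = 8
w2 = torch.randn(8, 4)           # [intermediate, hidden]
bias = torch.randn(4)

reference = x @ w2 + bias

# Shard w2 along the intermediate dim and give each shard bias / tp_size.
partials = []
for rank in range(tp_size):
    w_shard = w2.chunk(tp_size, dim=0)[rank]
    x_shard = x.chunk(tp_size, dim=1)[rank]
    partials.append(x_shard @ w_shard + bias / tp_size)

reduced = sum(partials)          # stands in for the TP all-reduce
torch.testing.assert_close(reduced, reference)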
+ w2_weight /= module.tp_size + + w2_weight_shard = load_weight_shard(w2_weight, + module.tp_size, + module.tp_rank, + TensorParallelMode.ROW, + device=device) + + # FIXME: this depends on the kernel internals + epilogue_tile_m = 128 + + # Keep weights in device buffer + dst_w2_weight_gpu.copy_(w2_weight_shard.view(dst_w2_weight_gpu.dtype), + non_blocking=dst_on_gpu) + # Get permuted indices + permute_indices = trtllmgen_maybe_get_cached_w2_permute_indices( + dst_w2_weight_gpu, self._cache_permute_indices, epilogue_tile_m) + + # Shuffle the weight according to permute indices + processed_w2_weight = torch.ops.trtllm.shuffle_matrix( + dst_w2_weight_gpu, permute_indices.to(dst_w2_weight_gpu.device)) + + # Copy the result into device buffer + dst_w2_weight_gpu.copy_(processed_w2_weight.view( + dst_w2_weight_gpu.dtype), + non_blocking=dst_on_gpu) + + if not dst_on_gpu: + dst_w2_weight.copy_(dst_w2_weight_gpu) + + def load_expert_w3_w1_weight_scale_nvfp4( + self, + module: torch.nn.Module, + w1_weight_scale: torch.Tensor, + w3_weight_scale: torch.Tensor, + dst_w3_w1_weight_scale: torch.Tensor, + num_elts_per_sf: int = 16): + device = torch.device(f"cuda:{torch.cuda.current_device()}") + dst_on_gpu = dst_w3_w1_weight_scale.device.type == "cuda" + dst_w3_w1_weight_scale_gpu = dst_w3_w1_weight_scale if dst_on_gpu else dst_w3_w1_weight_scale.cuda( + ) + + alignment = _get_weight_alignment(self.weight_alignment, + module.scaling_vector_size, + module.tp_size, + w3_weight_scale.shape[0]) + w1_weight_scale = maybe_pad_for_mxfp4( + w1_weight_scale, + self.input_hidden_alignment // module.scaling_vector_size, + alignment) + w3_weight_scale = maybe_pad_for_mxfp4( + w3_weight_scale, + self.input_hidden_alignment // module.scaling_vector_size, + alignment) + + w1_weight_scale = load_weight_shard(w1_weight_scale, + module.tp_size, + module.tp_rank, + TensorParallelMode.COLUMN, + device=device) + w3_weight_scale = load_weight_shard(w3_weight_scale, + module.tp_size, + module.tp_rank, + TensorParallelMode.COLUMN, + device=device) + # Keep weights in device buffer + dst_w3_weight_scale, dst_w1_weight_scale = dst_w3_w1_weight_scale_gpu.chunk( + 2, dim=0) + dst_w3_weight_scale.copy_( + w3_weight_scale.view(dst_w3_weight_scale.dtype)) + dst_w1_weight_scale.copy_( + w1_weight_scale.view(dst_w1_weight_scale.dtype)) + + orig_shape = dst_w3_w1_weight_scale_gpu.shape + + # trtllm-gen specific block scales preprocessing logics + epilogue_tile_m = 128 # FIXME + + # Get permute indices + permute_indices = trtllmgen_maybe_get_cached_w3_w1_permute_indices( + dst_w3_w1_weight_scale_gpu.view(float4_sf_dtype), + self._cache_permute_indices, + epilogue_tile_m, + num_elts_per_sf=num_elts_per_sf) + + # Shuffle the weight according to permute indices + w3_w1_weight_scale = torch.ops.trtllm.shuffle_matrix( + dst_w3_w1_weight_scale_gpu.view(float4_sf_dtype), permute_indices) + + # Assert should only be removed during debugging + assert w3_w1_weight_scale.is_cuda, "w3_w1_weight_scale.is_cuda should be true or suffer from slow speed" + # Interleave the weight. 
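# The weight-loading helpers in this class stage the destination tensor on the
# GPU when it lives on the CPU: the shuffle/interleave preprocessing runs on a
# CUDA copy and the result is copied back at the end. A minimal sketch of that
# pattern; the real code calls torch.ops.trtllm.shuffle_matrix and
# block_scale_interleave where the placeholder comment sits.
import torch

def process_on_gpu(dst: torch.Tensor, src: torch.Tensor) -> None:
    dst_on_gpu = dst.device.type == "cuda"
    work = dst if dst_on_gpu else dst.cuda()
    work.copy_(src.to(work.device))
    # ... GPU-side preprocessing would happen here on `work` ...
    if not dst_on_gpu:
        dst.copy_(work)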
+ processed_w3_w1_weight_scale = torch.ops.trtllm.block_scale_interleave( + w3_w1_weight_scale.view(float4_sf_dtype).reshape(orig_shape)) + # Copy the result into device buffer + dst_w3_w1_weight_scale_gpu.copy_( + processed_w3_w1_weight_scale.view( + self.block_scales_dtype).reshape(orig_shape)) + + if not dst_on_gpu: + dst_w3_w1_weight_scale.copy_(dst_w3_w1_weight_scale_gpu) + + def load_expert_w2_weight_scale_nvfp4(self, + module: torch.nn.Module, + w2_weight_scale: torch.Tensor, + dst_w2_weight_scale: torch.Tensor, + num_elts_per_sf: int = 16): + device = torch.device(f"cuda:{torch.cuda.current_device()}") + dst_on_gpu = dst_w2_weight_scale.device.type == "cuda" + dst_w2_weight_scale_gpu = dst_w2_weight_scale if dst_on_gpu else dst_w2_weight_scale.cuda( + ) + + alignment = _get_weight_alignment(self.weight_alignment, + module.scaling_vector_size, + module.tp_size, + w2_weight_scale.shape[-1]) + w2_weight_scale = maybe_pad_for_mxfp4( + w2_weight_scale, alignment // module.scaling_vector_size, + self.weight_alignment) + + w2_weight_scale = load_weight_shard(w2_weight_scale, + module.tp_size, + module.tp_rank, + TensorParallelMode.ROW, + device=device) + # Keep weights in device buffer + dst_w2_weight_scale_gpu.copy_( + w2_weight_scale.view(dst_w2_weight_scale_gpu.dtype)) + + orig_shape = dst_w2_weight_scale_gpu.shape + + # trtllm-gen specific block scales preprocessing logics + epilogue_tile_m = 128 # FIXME: read from kernel + + # Assert should only be removed during debugging + assert dst_w2_weight_scale_gpu.is_cuda, "dst_w2_weight_scale.is_cuda should be true or suffer from slow speed" + + # Get permute indices + permute_indices = trtllmgen_maybe_get_cached_w2_permute_indices( + dst_w2_weight_scale_gpu.view(float4_sf_dtype), + self._cache_permute_indices, + epilogue_tile_m, + num_elts_per_sf=num_elts_per_sf) + + # Shuffle the weight according to permute indices + w_shuffled = torch.ops.trtllm.shuffle_matrix( + dst_w2_weight_scale_gpu.view(dtype=float4_sf_dtype), + permute_indices) + # Interleave the weight. + processed_w2_weight_scale = torch.ops.trtllm.block_scale_interleave( + w_shuffled) + # Copy the result into device buffer + dst_w2_weight_scale_gpu.copy_( + processed_w2_weight_scale.view( + self.block_scales_dtype).reshape(orig_shape)) + + if not dst_on_gpu: + dst_w2_weight_scale.copy_(dst_w2_weight_scale_gpu) + + def load_quant_scales(self, module: torch.nn.Module, weights: Dict): + super().load_quant_scales(module, weights) + + # Normalize biases to account for the global scale factors, + # matching the kernel's expectation (similar to test_moe.py logic). 
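# Sketch of the scale folding applied just below (and mirrored explicitly in
# the test_moe.py hunks later in this diff). The first GEMM accumulates in the
# quantized domain, roughly (x * x_global) @ (w * w_global), and is only
# multiplied back by the dequant factor afterwards. Additive terms (bias,
# swiglu beta) and the clamp limit therefore have to be scaled by
# x_global * w_global up front so that they survive dequantization unchanged.
# Assuming fc31_alpha == 1 / (x_global * w_global), dividing by fc31_alpha is
# the same multiplication the test performs with the global scales.
import torch

x_global, w_global = 3.0, 5.0          # made-up global scale factors
acc_true = torch.tensor([0.25, -1.5])  # "true", unscaled accumulator
bias_true = torch.tensor([0.1, 0.2])

fc31_alpha = 1.0 / (x_global * w_global)

# What the kernel sees: an accumulator in the scaled domain plus a bias that
# was pre-divided by fc31_alpha (i.e. multiplied by x_global * w_global).
acc_scaled = acc_true * x_global * w_global
bias_folded = bias_true / fc31_alpha

dequant = (acc_scaled + bias_folded) * fc31_alpha
torch.testing.assert_close(dequant, acc_true + bias_true)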
+ if module.w3_w1_bias is not None: + # gemm1_bias * gemm1_scales_global * hidden_states_scale_global + module.w3_w1_bias.data.div_((module.fc31_alpha.data).view(-1, 1)) + + if module.w2_bias is not None: + # gemm2_bias * c_global_sf * gemm2_scales_global + module.w2_bias.data.div_((module.fc2_alpha.data).view(-1, 1)) + + if module.swiglu_beta is not None: + module.swiglu_beta.data.div_((module.fc31_alpha.data)) + + if module.swiglu_limit is not None: + module.swiglu_limit.data.div_((module.fc31_alpha.data)) + + +class W4A8NVFP4FP8TRTLLMGenFusedMoEMethod(NVFP4TRTLLMGenFusedMoEMethodBase): def create_weights(self, module: torch.nn.Module): weight_vec_size = torch.iinfo(self.weight_dtype).bits // 4 diff --git a/tests/integration/defs/accuracy/references/gsm8k.yaml b/tests/integration/defs/accuracy/references/gsm8k.yaml index f79a0eb0a6c..6a0c0a4fc55 100644 --- a/tests/integration/defs/accuracy/references/gsm8k.yaml +++ b/tests/integration/defs/accuracy/references/gsm8k.yaml @@ -270,6 +270,13 @@ GPT-OSS/20B-MXFP4: - quant_algo: W4A16_MXFP4 kv_cache_quant_algo: FP8 accuracy: 85.0 +GPT-OSS/20B-NVFP4: + - accuracy: 85.0 + - quant_algo: NVFP4 + accuracy: 85.0 + - quant_algo: NVFP4 + kv_cache_quant_algo: FP8 + accuracy: 85.0 LGAI-EXAONE/EXAONE-4.0-32B: - accuracy: 88.36 ByteDance-Seed/Seed-OSS-36B-Instruct: diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index 7ae3da08783..cc30ec0be00 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -4270,6 +4270,44 @@ def test_w4_2gpus(self, kv_cache_dtype, moe_backend, tp_size, pp_size, task.evaluate(llm, extra_evaluator_kwargs=self.extra_evaluator_kwargs) + @pytest.mark.skip_less_device(2) + @pytest.mark.skip_blackwell + @pytest.mark.parametrize( + "tp_size,pp_size,ep_size,attention_dp,cuda_graph,overlap_scheduler", [ + (2, 1, 1, False, True, True), + (2, 1, 2, False, True, True), + (2, 1, 2, True, True, True), + ], + ids=["tp2", "ep2", "dp2"]) + def test_w4_2gpus_nvfp4(self, tp_size, pp_size, ep_size, attention_dp, + cuda_graph, overlap_scheduler, mocker): + pytest.skip("Models not uploaded to CI") + pytorch_config = dict( + disable_overlap_scheduler=not overlap_scheduler, + cuda_graph_config=CudaGraphConfig() if cuda_graph else None) + + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.4, + dtype="auto") + + llm = LLM("./nvfp4ckpt", + tensor_parallel_size=tp_size, + pipeline_parallel_size=pp_size, + moe_expert_parallel_size=ep_size, + kv_cache_config=kv_cache_config, + max_seq_len=8192, + **pytorch_config, + enable_attention_dp=attention_dp, + moe_config=MoeConfig(backend="TRTLLM")) + + with llm: + model_name = "GPT-OSS/20B-NVFP4" + task = GSM8K(model_name) + mocker.patch.object(GSM8K, "MAX_OUTPUT_LEN", 8192) + mocker.patch.dict(GSM8K.EVALUATE_KWARGS, + {"scores_filter": "exact_match,flexible-extract"}) + task.evaluate(llm, + extra_evaluator_kwargs=self.extra_evaluator_kwargs) + @pytest.mark.skip_less_device(4) @pytest.mark.parametrize( "kv_cache_dtype", diff --git a/tests/unittest/_torch/modules/test_fused_moe.py b/tests/unittest/_torch/modules/test_fused_moe.py index 959736a4648..5a0b641f1b9 100644 --- a/tests/unittest/_torch/modules/test_fused_moe.py +++ b/tests/unittest/_torch/modules/test_fused_moe.py @@ -11,6 +11,7 @@ import pytest import torch import torch.nn as nn +import torch.nn.functional as F from _torch.helpers import (calc_woq_tolerence, per_block_cast_to_fp8, 
per_block_cast_to_fp8_e8m0, per_token_cast_to_fp8_e8m0) @@ -1398,6 +1399,53 @@ def test_fused_moe_nvfp4(dtype, moe_backend, finalize_fusion, and moe_backend in ["TRTLLM", "CUTLASS"] else "0" }) + run_fused_moe_nvfp4(dtype, moe_backend, finalize_fusion) + + +@skip_pre_blackwell +@pytest.mark.parametrize("hidden_size, intermediate_size", [(2880, 2880)]) +@pytest.mark.parametrize("swiglu_alpha", [1, 0.1], ids=lambda v: f"alpha{v}") +@pytest.mark.parametrize("swiglu_beta", [0, 1], ids=lambda v: f"beta{v}") +@pytest.mark.parametrize("swiglu_limit", [float("inf"), 1], + ids=lambda v: f"limit{v}") +@pytest.mark.parametrize("enable_configurable_moe", [0, 1], + ids=lambda x: "" + if x == 0 else "enable_configurable_moe") +def test_fused_moe_nvfp4_gptoss_style(hidden_size, intermediate_size, + swiglu_alpha, swiglu_beta, swiglu_limit, + enable_configurable_moe, mocker): + mocker.patch.dict(os.environ, { + "ENABLE_CONFIGURABLE_MOE": + "1" if enable_configurable_moe == 1 else "0" + }) + + run_fused_moe_nvfp4(dtype=torch.bfloat16, + moe_backend="TRTLLM", + finalize_fusion=False, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_experts=32, + top_k=4, + seq_len=256, + gptoss_style=True, + swiglu_alpha=swiglu_alpha, + swiglu_beta=swiglu_beta, + swiglu_limit=swiglu_limit) + + +def run_fused_moe_nvfp4(dtype, + moe_backend, + finalize_fusion, + hidden_size=512, + intermediate_size=512, + num_experts=8, + top_k=2, + seq_len=4, + gptoss_style=False, + swiglu_alpha=None, + swiglu_beta=None, + swiglu_limit=None): + if moe_backend == "TRTLLM": if dtype == torch.float16: pytest.skip("TRTLLM NVFP4 MoE backend does not support float16 yet") @@ -1424,11 +1472,11 @@ def test_fused_moe_nvfp4(dtype, moe_backend, finalize_fusion, with torch.device(f"cuda:{mapping.rank}"): SCALING_VECTOR_SIZE = 16 - SEQ_LEN = 4 - HIDDEN_SIZE = 512 - INTERMEDIATE_SIZE = 512 - NUM_EXPERTS = 8 - TOP_K = 2 + SEQ_LEN = seq_len + HIDDEN_SIZE = hidden_size + INTERMEDIATE_SIZE = intermediate_size + NUM_EXPERTS = num_experts + TOP_K = top_k routing_method = RenormalizeMoeRoutingMethod(top_k=TOP_K) torch.manual_seed(0) torch.cuda.manual_seed(0) @@ -1455,24 +1503,38 @@ def test_fused_moe_nvfp4(dtype, moe_backend, finalize_fusion, device="cuda") * 0.05 w3_sf_global = (448 * 6) / w3_weight.abs().max().float() + if gptoss_style: + w1_bias = torch.randn(INTERMEDIATE_SIZE, + device='cuda', + dtype=torch.float) + w2_bias = torch.randn(HIDDEN_SIZE, + device='cuda', + dtype=torch.float) + w3_bias = torch.randn(INTERMEDIATE_SIZE, + device='cuda', + dtype=torch.float) + weights[f"{expert_id}.w1.bias"] = w1_bias + weights[f"{expert_id}.w2.bias"] = w2_bias + weights[f"{expert_id}.w3.bias"] = w3_bias + w3_w1_global = min( w1_sf_global, w3_sf_global) # w3 global and w1 global must be the same - w1_weight_nvfp4, w1_sf_block = torch.ops.trtllm.fp4_quantize( - w1_weight, w3_w1_global, SCALING_VECTOR_SIZE, False) - w1_sf_block_unswizzled = torch.ops.trtllm.block_scale_interleave_reverse( - w1_sf_block.cpu().view(INTERMEDIATE_SIZE, -1)) + w1_weight_nvfp4, w1_sf_block_unswizzled = torch.ops.trtllm.fp4_quantize( + w1_weight, w3_w1_global, SCALING_VECTOR_SIZE, False, False) + w1_sf_block_unswizzled = w1_sf_block_unswizzled.view( + INTERMEDIATE_SIZE, -1) - w2_weight_nvfp4, w2_sf_block = torch.ops.trtllm.fp4_quantize( - w2_weight, w2_sf_global, SCALING_VECTOR_SIZE, False) - w2_sf_block_unswizzled = torch.ops.trtllm.block_scale_interleave_reverse( - w2_sf_block.cpu().view(HIDDEN_SIZE, -1)) + w2_weight_nvfp4, w2_sf_block_unswizzled = 
torch.ops.trtllm.fp4_quantize( + w2_weight, w2_sf_global, SCALING_VECTOR_SIZE, False, False) + w2_sf_block_unswizzled = w2_sf_block_unswizzled.view( + HIDDEN_SIZE, -1) - w3_weight_nvfp4, w3_sf_block = torch.ops.trtllm.fp4_quantize( - w3_weight, w3_w1_global, SCALING_VECTOR_SIZE, False) - w3_sf_block_unswizzled = torch.ops.trtllm.block_scale_interleave_reverse( - w3_sf_block.cpu().view(INTERMEDIATE_SIZE, -1)) + w3_weight_nvfp4, w3_sf_block_unswizzled = torch.ops.trtllm.fp4_quantize( + w3_weight, w3_w1_global, SCALING_VECTOR_SIZE, False, False) + w3_sf_block_unswizzled = w3_sf_block_unswizzled.view( + INTERMEDIATE_SIZE, -1) w1_input_scale = x_sf_global.cuda() w2_input_scale = x_sf_global.cuda() @@ -1497,6 +1559,23 @@ def test_fused_moe_nvfp4(dtype, moe_backend, finalize_fusion, weights[f"{expert_id}.w2.weight_scale_2"] = 1.0 / w2_sf_global weights[f"{expert_id}.w3.weight_scale_2"] = 1.0 / w3_w1_global + swiglu_alpha_tensor = None + swiglu_beta_tensor = None + swiglu_limit_tensor = None + if gptoss_style: + swiglu_alpha_tensor = torch.full((NUM_EXPERTS, ), + swiglu_alpha, + device='cuda', + dtype=torch.float) + swiglu_beta_tensor = torch.full((NUM_EXPERTS, ), + swiglu_beta, + device='cuda', + dtype=torch.float) + swiglu_limit_tensor = torch.full((NUM_EXPERTS, ), + swiglu_limit, + device='cuda', + dtype=torch.float) + quant_config = QuantConfig(quant_algo=QuantAlgo.NVFP4) # Create pretrained_config with necessary parameters @@ -1514,6 +1593,10 @@ def test_fused_moe_nvfp4(dtype, moe_backend, finalize_fusion, quant_config=quant_config, moe_backend=moe_backend, moe_disable_finalize_fusion=not finalize_fusion), + bias=gptoss_style, + swiglu_alpha=swiglu_alpha_tensor, + swiglu_beta=swiglu_beta_tensor, + swiglu_limit=swiglu_limit_tensor, ) fused_moe.load_weights([weights]) fused_moe.post_load_weights() @@ -1526,7 +1609,11 @@ def test_fused_moe_nvfp4(dtype, moe_backend, finalize_fusion, hidden_size=HIDDEN_SIZE, intermediate_size=INTERMEDIATE_SIZE, dtype=dtype, - model_config=ModelConfig(quant_config=quant_config)) + model_config=ModelConfig(quant_config=quant_config), + bias=gptoss_style, + swiglu_alpha=swiglu_alpha, + swiglu_beta=swiglu_beta, + swiglu_limit=swiglu_limit) ref_fused_moe.load_weights([weights]) ref_fused_moe.cuda() @@ -1534,11 +1621,33 @@ def test_fused_moe_nvfp4(dtype, moe_backend, finalize_fusion, with torch.inference_mode(): ref_output = ref_fused_moe.forward(x, router_logits) - with torch.inference_mode(), autotune(): - fused_moe.forward(x, router_logits) + if not gptoss_style: + with torch.inference_mode(), autotune(): + fused_moe.forward(x, router_logits) + else: + # We skip autotune for gptoss style to reduce memory usage since the input shape is already quite large. 
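# Standalone version of the clamped "gpt-oss style" SwiGLU that the reference
# MoE defined further down in this diff uses (custom_swiglu): the gate is
# clamped from above, the linear branch from both sides, the gate goes through
# a sigmoid scaled by alpha, and beta shifts the linear branch. With alpha=1,
# beta=0 and limit=inf it reduces to the usual SiLU-gated form, checked below.
import torch

def clamped_swiglu(x: torch.Tensor, alpha: float = 1.0, beta: float = 0.0,
                   limit: float = float("inf")) -> torch.Tensor:
    gate, value = x.chunk(2, dim=-1)
    if limit != float("inf"):
        gate = gate.clamp(max=limit)
        value = value.clamp(min=-limit, max=limit)
    return (gate * torch.sigmoid(alpha * gate)) * (value + beta)

x = torch.randn(4, 8)
default = clamped_swiglu(x)                       # plain SwiGLU
gate, value = x.chunk(2, dim=-1)
torch.testing.assert_close(default, torch.nn.functional.silu(gate) * value)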
+ with torch.inference_mode(): + fused_moe.forward(x, router_logits) output = fused_moe.forward(x, router_logits) - torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.15) + + if gptoss_style: + rtol = 0.1 + atol = 0.1 + percent = 0.95 + else: + rtol = 1e-2 + atol = 0.15 + percent = None + + if gptoss_style: + check_accuracy(output, + ref_output, + rtol=rtol, + atol=atol, + percent=percent) + else: + torch.testing.assert_close(output, ref_output, rtol=rtol, atol=atol) if not test_all_kernels: return @@ -1551,10 +1660,17 @@ def test_fused_moe_nvfp4(dtype, moe_backend, finalize_fusion, for tactic in all_tactics: with AutoTuner.get().replay(tactic), torch.inference_mode(): output = fused_moe.forward(x, router_logits) - torch.testing.assert_close(output, - ref_output, - rtol=1e-2, - atol=0.15) + if gptoss_style: + check_accuracy(output, + ref_output, + rtol=rtol, + atol=atol, + percent=percent) + else: + torch.testing.assert_close(output, + ref_output, + rtol=rtol, + atol=atol) @skip_pre_blackwell @@ -2690,7 +2806,10 @@ def __init__(self, dtype: Optional[torch.dtype] = None, model_config: ModelConfig = ModelConfig(), use_cute_dsl_blockscaling_mm: bool = False, - bias=False): + bias=False, + swiglu_alpha: Optional[float] = None, + swiglu_beta: Optional[float] = None, + swiglu_limit: Optional[float] = None): super().__init__() self.num_experts = num_experts self.routing_method = routing_method @@ -2701,6 +2820,19 @@ def __init__(self, self.dtype = dtype self.quant_config = model_config.quant_config + def custom_swiglu(x): + gate, value = x.chunk(2, dim=-1) + if swiglu_limit is not None and swiglu_limit != float("inf"): + gate = gate.clamp(max=swiglu_limit) + value = value.clamp(min=-swiglu_limit, max=swiglu_limit) + + alpha = swiglu_alpha if swiglu_alpha is not None else 1.0 + gate_act = gate * torch.sigmoid(gate * alpha) + + beta = swiglu_beta if swiglu_beta is not None else 0.0 + + return gate_act * (value + beta) + self.experts = nn.ModuleList([ GatedMLP( hidden_size=self.hidden_size, @@ -2709,6 +2841,8 @@ def __init__(self, dtype=self.dtype, config=model_config, use_cute_dsl_blockscaling_mm=use_cute_dsl_blockscaling_mm, + activation=custom_swiglu + if swiglu_alpha is not None else F.silu, ) for _ in range(self.num_experts) ]) diff --git a/tests/unittest/_torch/thop/serial/test_moe.py b/tests/unittest/_torch/thop/serial/test_moe.py index e252fc6047a..bffbed9d60b 100644 --- a/tests/unittest/_torch/thop/serial/test_moe.py +++ b/tests/unittest/_torch/thop/serial/test_moe.py @@ -547,11 +547,23 @@ def run_moe_reference_fp4(args): args.gemm2_weights, args.gemm2_scales, 1 / args.gemm2_scales_global, sf_vec_size).cuda() - args_dequant = moe_args_dequant( - args.num_tokens, args.num_experts, args.hidden_size, - args.intermediate_size, args.top_k, args.padding, hidden_states_dequant, - args.expert_logits, gemm1_weights_dequant, gemm2_weights_dequant, - args.permute_info, args.use_routing_scales_on_input) + args_dequant = moe_args_dequant(args.num_tokens, + args.num_experts, + args.hidden_size, + args.intermediate_size, + args.top_k, + args.padding, + hidden_states_dequant, + args.expert_logits, + gemm1_weights_dequant, + gemm2_weights_dequant, + args.permute_info, + args.use_routing_scales_on_input, + gemm1_bias=args.gemm1_bias, + gemm1_alpha=args.gemm1_alpha, + gemm1_beta=args.gemm1_beta, + gemm1_clamp_limit=args.gemm1_clamp_limit, + gemm2_bias=args.gemm2_bias) return run_moe_dequant(args_dequant, "fp4"), args_dequant @@ -1157,6 +1169,44 @@ def test_no_autotune(self, num_tokens, 
hidden_size, intermediate_size, use_autotune=False, use_topk_as_input=use_topk_as_input) + @pytest.mark.parametrize("num_tokens", [1]) + @pytest.mark.parametrize("hidden_size", [512]) + @pytest.mark.parametrize("intermediate_size", [512]) + @pytest.mark.parametrize( + "routing_info", + [ + pytest.param( + { + "num_experts": 128, + "top_k": 4, + "n_groups": None, + "top_k_groups": None, + "routed_scaling": None, + "has_routing_bias": False, + "routing_method_type": RoutingMethodType.Renormalize + }, + id="RoutingGPTOSS") + ], + ) + @pytest.mark.parametrize("swiglu_alpha", [1, 0.1], + ids=lambda v: f"alpha{v}") + @pytest.mark.parametrize("swiglu_beta", [0, 1], ids=lambda v: f"beta{v}") + @pytest.mark.parametrize("swiglu_limit", [float("inf"), 1], + ids=lambda v: f"limit{v}") + def test_gptoss_style_nvfp4(self, num_tokens, hidden_size, + intermediate_size, routing_info, swiglu_alpha, + swiglu_beta, swiglu_limit): + + self.run_moe_fp4_test(num_tokens, + hidden_size, + intermediate_size, + routing_info, + use_autotune=False, + gptoss_style=True, + swiglu_alpha=swiglu_alpha, + swiglu_beta=swiglu_beta, + swiglu_limit=swiglu_limit) + @pytest.mark.parametrize("num_tokens", [1]) @pytest.mark.parametrize("hidden_size", [1024]) @pytest.mark.parametrize("intermediate_size", [1024]) @@ -1219,9 +1269,17 @@ def test_online_eplb288_topk_input(self, num_tokens, hidden_size, use_autotune=True, use_topk_as_input=True) - def run_moe_fp4_test(self, num_tokens: int, hidden_size: int, - intermediate_size: int, routing_info: dict, - use_autotune: bool, use_topk_as_input: bool) -> None: + def run_moe_fp4_test(self, + num_tokens: int, + hidden_size: int, + intermediate_size: int, + routing_info: dict, + use_autotune: bool, + use_topk_as_input: bool = False, + gptoss_style: bool = False, + swiglu_alpha: float = None, + swiglu_beta: float = None, + swiglu_limit: float = None) -> None: torch.random.manual_seed(0) @@ -1289,6 +1347,39 @@ def run_moe_fp4_test(self, num_tokens: int, hidden_size: int, device='cuda', dtype=torch.bfloat16) + gemm1_bias = None + gemm2_bias = None + swiglu_alpha_tensor = None + swiglu_beta_tensor = None + swiglu_limit_tensor = None + + if gptoss_style: + gemm1_bias = 50 * torch.randn(num_experts, + 2 * intermediate_size, + device='cuda', + dtype=torch.float) + gemm2_bias = 50 * torch.randn( + num_experts, hidden_size, device='cuda', dtype=torch.float) + + # waived due to missing kernel support for bias in nvfp4 + #gemm1_bias[:] = 0 + #gemm2_bias[:] = 0 + + swiglu_alpha_tensor = torch.full((num_experts, ), + swiglu_alpha, + device='cuda', + dtype=torch.float) + + swiglu_beta_tensor = torch.full((num_experts, ), + swiglu_beta, + device='cuda', + dtype=torch.float) + + swiglu_limit_tensor = torch.full((num_experts, ), + swiglu_limit, + device='cuda', + dtype=torch.float) + use_ue8m0 = False # Quantize hidden states. Produces scales for activations in 128x4 layout for ref impl. 
hidden_states_fp4_bytes, hidden_states_scale_fp4_bytes, hidden_states_scale_global = quant_fp4( @@ -1343,14 +1434,29 @@ def run_moe_fp4_test(self, num_tokens: int, hidden_size: int, permute_info, scores = routing_reference_renormalize_naive( expert_logits, top_k, padding) - args = moe_args(num_tokens, num_experts, hidden_size, intermediate_size, - top_k, padding, hidden_states_fp4_bytes, + args = moe_args(num_tokens, + num_experts, + hidden_size, + intermediate_size, + top_k, + padding, + hidden_states_fp4_bytes, hidden_states_scale_fp4_bytes, - hidden_states_scale_global, scores, - gemm1_weights_fp4_bytes, gemm1_scales_fp4_bytes, - gemm1_scales_global, gemm2_weights_fp4_bytes, - gemm2_scales_fp4_bytes, gemm2_scales_global, - permute_info, False) + hidden_states_scale_global, + scores, + gemm1_weights_fp4_bytes, + gemm1_scales_fp4_bytes, + gemm1_scales_global, + gemm2_weights_fp4_bytes, + gemm2_scales_fp4_bytes, + gemm2_scales_global, + permute_info, + False, + gemm1_bias=gemm1_bias, + gemm1_alpha=swiglu_alpha_tensor, + gemm1_beta=swiglu_beta_tensor, + gemm1_clamp_limit=swiglu_limit_tensor, + gemm2_bias=gemm2_bias) # # Run the reference implementations # @@ -1364,12 +1470,17 @@ def run_moe_fp4_test(self, num_tokens: int, hidden_size: int, # Reorder rows of W1 and scales for fused gated activation gemm1_weights_fp4_interleaved = [] gemm1_scales_fp4_interleaved = [] + gemm1_bias_interleaved = [] for i in range(num_experts): gemm1_weights_fp4_interleaved.append( reorder_rows_for_gated_act_gemm(gemm1_weights_fp4[i].clone())) gemm1_scales_fp4_interleaved.append( reorder_rows_for_gated_act_gemm( gemm1_scales_linear_fp4[i].clone())) + if gemm1_bias is not None: + gemm1_bias_interleaved.append( + reorder_rows_for_gated_act_gemm( + gemm1_bias[i].clone().reshape(-1, 1))) # Stack weights and scales for all experts gemm1_weights_fp4_interleaved = torch.stack( @@ -1384,8 +1495,10 @@ def run_moe_fp4_test(self, num_tokens: int, hidden_size: int, # Shuffle weights and scaling factors for transposed mma output gemm1_weights_fp4_shuffled = [] gemm1_scales_fp4_shuffled = [] + gemm1_bias_shuffled = [] gemm2_weights_fp4_shuffled = [] gemm2_scales_fp4_shuffled = [] + gemm2_bias_shuffled = [] for i in range(num_experts): gemm1_weights_fp4_shuffled.append( shuffle_matrix_a( @@ -1395,6 +1508,10 @@ def run_moe_fp4_test(self, num_tokens: int, hidden_size: int, shuffle_matrix_sf_a( gemm1_scales_fp4_interleaved[i].view(torch.uint8), epilogue_tile_m)) + if gemm1_bias is not None: + gemm1_bias_shuffled.append( + shuffle_matrix_a(gemm1_bias_interleaved[i], + epilogue_tile_m)) gemm2_weights_fp4_shuffled.append( shuffle_matrix_a(gemm2_weights_fp4[i].view(torch.uint8), @@ -1403,6 +1520,10 @@ def run_moe_fp4_test(self, num_tokens: int, hidden_size: int, shuffle_matrix_sf_a( gemm2_scales_linear_fp4[i].view(torch.uint8), epilogue_tile_m)) + if gemm2_bias is not None: + gemm2_bias_shuffled.append( + shuffle_matrix_a(gemm2_bias[i].clone().reshape(-1, 1), + epilogue_tile_m)) # Stack weights for all experts gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled) @@ -1415,10 +1536,35 @@ def run_moe_fp4_test(self, num_tokens: int, hidden_size: int, torch.float8_e4m3fn).reshape(num_experts, hidden_size, intermediate_size // 16) + if gemm1_bias is not None: + gemm1_bias_shuffled = torch.stack(gemm1_bias_shuffled).reshape( + num_experts, -1) + else: + gemm1_bias_shuffled = None + + if gemm2_bias is not None: + gemm2_bias_shuffled = torch.stack(gemm2_bias_shuffled).reshape( + num_experts, -1) + else: + gemm2_bias_shuffled = 
None + # # Run the TRT-LLM kernel # + if gptoss_style: + # NOTE: correct the beta and clamp to account for the global scale factor + # Check cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/trtllmGen_bmm_export/GemmGatedActOptions.h + # for more details + swiglu_beta_tensor = swiglu_beta_tensor * args.gemm1_scales_global * args.hidden_states_scale_global + swiglu_limit_tensor = swiglu_limit_tensor * args.gemm1_scales_global * args.hidden_states_scale_global + # Check cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/trtllmGen_bmm_export/BatchedGemmInterface.h + # for more details + gemm1_bias_shuffled = gemm1_bias_shuffled * args.gemm1_scales_global[:, + None] * args.hidden_states_scale_global + gemm2_bias_shuffled = gemm2_bias_shuffled * args_dequant.c_global_sf * args.gemm2_scales_global[:, + None] + # c_global_sf: fc2_input_scale scale_c_fc1 = args_dequant.c_global_sf * ( 1.0 / args.gemm1_scales_global) * (1.0 / @@ -1449,8 +1595,13 @@ def run_moe_fp4_test(self, num_tokens: int, hidden_size: int, hidden_states_scale_linear_fp4, gemm1_weights_fp4_shuffled, gemm1_scales_fp4_shuffled, + gemm1_bias_shuffled, # Bias + swiglu_alpha_tensor, # Alpha + swiglu_beta_tensor, # Beta + swiglu_limit_tensor, # Limit gemm2_weights_fp4_shuffled, gemm2_scales_fp4_shuffled, + gemm2_bias_shuffled, # Bias scale_c_fc1, scale_gate_fc1, scale_c_fc2, @@ -1469,11 +1620,20 @@ def run_moe_fp4_test(self, num_tokens: int, hidden_size: int, torch.cuda.synchronize() output_dequant_actual = output[0].to(torch.float) + if gptoss_style: + atol = 0.2 + rtol = 0.2 + percent = 0.85 + else: + atol = 0.1 + rtol = 0.85 + percent = 0.925 + check_accuracy(output_dequant_reference, output_dequant_actual, - atol=0.1, - rtol=0.85, - percent=0.925) + atol=atol, + rtol=rtol, + percent=percent) def run_moe_fp8_fp4_test(self, num_tokens: int, hidden_size: int, intermediate_size: int, routing_info: dict,
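# Why the test above reshapes each per-expert bias to a column before
# shuffle_matrix_a: the shuffle reorders weight rows for the epilogue tile
# layout, and since the bias is added per output row it presumably has to
# follow exactly the same row order. Toy version with an explicit permutation
# standing in for the real trtllm shuffle.
import torch

rows, cols = 8, 4
w = torch.randn(rows, cols)
b = torch.randn(rows)
perm = torch.randperm(rows)

w_shuf = w[perm]
b_shuf = b.reshape(-1, 1)[perm].reshape(-1)   # same reshape trick as the test

x = torch.randn(cols)
# Permuted weight plus permuted bias gives the permuted unshuffled result.
torch.testing.assert_close(w_shuf @ x + b_shuf, (w @ x + b)[perm])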