diff --git a/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py b/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py
index 07fad25df49..d5146db47b3 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py
@@ -1124,6 +1124,26 @@ def post_load_weights(self):
         )
         return self.backend.post_load_weights()
 
+    def process_weights_after_loading(self):
+        """
+        Process weights after loading; delegated to the backend.
+
+        """
+        assert hasattr(self.backend, "process_weights_after_loading"), (
+            f"Backend {self.backend.__class__.__name__} must implement process_weights_after_loading()"
+        )
+        return self.backend.process_weights_after_loading()
+
+    def pre_reload_weights(self):
+        """
+        Prepare for reloading weights; delegated to the backend.
+
+        """
+        assert hasattr(self.backend, "pre_reload_weights"), (
+            f"Backend {self.backend.__class__.__name__} must implement pre_reload_weights()"
+        )
+        return self.backend.pre_reload_weights()
+
     # ========== Communication and Quantization Properties ==========
 
     @property
diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py
index 71e13e1324b..565b3379dc1 100755
--- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py
@@ -1,3 +1,4 @@
+import inspect
 import os
 from functools import cached_property
 from typing import Dict, List, Optional, Tuple, Union
@@ -862,16 +863,22 @@ def load_weights(self,
         assert len(weights) == 1
         weights = weights[0]
 
-        if not isinstance(self.quant_method, UnquantizedFusedMoEMethod):
-            assert not allow_partial_loading, "Partial loading is not supported for quantized MoE now"
-            self.quant_method.load_weights(self, weights,
-                                           self.weight_loading_mode)
-        else:
-            self.quant_method.load_weights(
-                self,
-                weights,
-                self.weight_loading_mode,
-                allow_partial_loading=allow_partial_loading)
+        kargs = {}
+        if "allow_partial_loading" in inspect.getfullargspec(
+                self.quant_method.load_weights).args:
+            kargs["allow_partial_loading"] = allow_partial_loading
+        self.quant_method.load_weights(self, weights, self.weight_loading_mode,
+                                       **kargs)
 
     def post_load_weights(self):
         self.quant_method.post_load_weights(self)
+
+    def process_weights_after_loading(self):
+        if hasattr(self.quant_method, 'process_weights_after_loading'):
+            self.quant_method.process_weights_after_loading(self)
+
+    def pre_reload_weights(self):
+        assert hasattr(
+            self.quant_method, 'pre_reload_weights'
+        ), "pre_reload_weights is not supported for this quant method"
+        self.quant_method.pre_reload_weights(self)
diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_triton.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_triton.py
index 22d95411bc1..e817d317d21 100755
--- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_triton.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_triton.py
@@ -1401,3 +1401,13 @@ def load_weights(self,
 
     def post_load_weights(self):
         self.quant_method.post_load_weights(self)
+
+    def process_weights_after_loading(self):
+        if hasattr(self.quant_method, 'process_weights_after_loading'):
+            self.quant_method.process_weights_after_loading(self)
+
+    def pre_reload_weights(self):
+        assert hasattr(
+            self.quant_method, 'pre_reload_weights'
+        ), "pre_reload_weights is not supported for this quant method"
+        self.quant_method.pre_reload_weights(self)
diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py index df0a419b387..0ae4d929070 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py @@ -1,3 +1,4 @@ +import inspect import os from functools import cached_property from typing import Dict, List, Optional, Union @@ -21,9 +22,9 @@ # isort: off from .quantization import ( DeepSeekFP8BlockScalesFusedMoEMethod, NVFP4TRTLLMGenFusedMoEBaseMethod, - NVFP4TRTLLMGenFusedMoEMethod, UnquantizedFusedMoEMethod, - W4A8MXFP4FP8TRTLLMGenFusedMoEMethod, W4A8MXFP4MXFP8TRTLLMGenFusedMoEMethod, - W4A8NVFP4FP8TRTLLMGenFusedMoEMethod, W4A16MXFP4TRTLLMGenFusedMoEMethod) + NVFP4TRTLLMGenFusedMoEMethod, W4A8MXFP4FP8TRTLLMGenFusedMoEMethod, + W4A8MXFP4MXFP8TRTLLMGenFusedMoEMethod, W4A8NVFP4FP8TRTLLMGenFusedMoEMethod, + W4A16MXFP4TRTLLMGenFusedMoEMethod) # isort: on from .routing import BaseMoeRoutingMethod, DeepSeekV3MoeRoutingMethod @@ -273,20 +274,26 @@ def load_weights(self, assert len(weights) == 1 weights = weights[0] - if not isinstance(self.quant_method, UnquantizedFusedMoEMethod): - assert not allow_partial_loading, "Partial loading is not supported for quantized MoE now" - self.quant_method.load_weights(self, weights, - self.weight_loading_mode) - else: - self.quant_method.load_weights( - self, - weights, - self.weight_loading_mode, - allow_partial_loading=allow_partial_loading) + kargs = {} + if "allow_partial_loading" in inspect.getfullargspec( + self.quant_method.load_weights).args: + kargs["allow_partial_loading"] = allow_partial_loading + self.quant_method.load_weights(self, weights, self.weight_loading_mode, + **kargs) def post_load_weights(self): self.quant_method.post_load_weights(self) + def process_weights_after_loading(self): + if hasattr(self.quant_method, 'process_weights_after_loading'): + self.quant_method.process_weights_after_loading(self) + + def pre_reload_weights(self): + assert hasattr( + self.quant_method, 'pre_reload_weights' + ), "pre_reload_weights is not supported for this quant method" + self.quant_method.pre_reload_weights(self) + def quantize_input(self, x, post_quant_comm: bool = True): """Quantize inputs prior to post-communication (alltoall/allgather) or before MoE computation. 
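# Illustrative, standalone sketch (hypothetical names; not part of the patch): the
# load_weights hunks above (cutlass, triton, trtllm_gen) and the wide_ep hunk below all
# use the same dispatch pattern: forward `allow_partial_loading` to the quant method only
# when its load_weights signature accepts it, so quant methods that have not been
# migrated yet keep working unchanged.
import inspect


def legacy_load_weights(module, weights, mode):
    return ("legacy", mode)


def partial_aware_load_weights(module, weights, mode, allow_partial_loading=False):
    return ("partial", mode, allow_partial_loading)


def dispatch_load_weights(load_fn, module, weights, mode, allow_partial_loading):
    kwargs = {}
    # getfullargspec().args lists positional-or-keyword parameters, including defaulted ones
    if "allow_partial_loading" in inspect.getfullargspec(load_fn).args:
        kwargs["allow_partial_loading"] = allow_partial_loading
    return load_fn(module, weights, mode, **kwargs)


assert dispatch_load_weights(legacy_load_weights, None, {}, "vanilla", True) == ("legacy", "vanilla")
assert dispatch_load_weights(partial_aware_load_weights, None, {}, "vanilla", True) == (
    "partial", "vanilla", True)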
diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py index f0a605cc077..160fb7905fc 100755 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py @@ -1,3 +1,4 @@ +import inspect import os from typing import Dict, List, Optional, Tuple, Union @@ -935,20 +936,26 @@ def load_weights(self, assert len(weights) == 1 weights = weights[0] - if not isinstance(self.quant_method, UnquantizedFusedMoEMethod): - assert not allow_partial_loading, "Partial loading is not supported for quantized MoE now" - self.quant_method.load_weights(self, weights, - self.weight_loading_mode) - else: - self.quant_method.load_weights( - self, - weights, - self.weight_loading_mode, - allow_partial_loading=allow_partial_loading) + kargs = {} + if "allow_partial_loading" in inspect.getfullargspec( + self.quant_method.load_weights).args: + kargs["allow_partial_loading"] = allow_partial_loading + self.quant_method.load_weights(self, weights, self.weight_loading_mode, + **kargs) def post_load_weights(self): self.quant_method.post_load_weights(self) + def process_weights_after_loading(self): + if hasattr(self.quant_method, 'process_weights_after_loading'): + self.quant_method.process_weights_after_loading(self) + + def pre_reload_weights(self): + assert hasattr( + self.quant_method, 'pre_reload_weights' + ), "pre_reload_weights is not supported for this quant method" + self.quant_method.pre_reload_weights(self) + def forward_fake( self, x: Union[torch.Tensor, Fp4QuantizedTensor], diff --git a/tensorrt_llm/_torch/modules/fused_moe/interface.py b/tensorrt_llm/_torch/modules/fused_moe/interface.py index ead2510b0d3..cfdd160ba09 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/interface.py +++ b/tensorrt_llm/_torch/modules/fused_moe/interface.py @@ -521,12 +521,20 @@ def create_weights(self): raise NotImplementedError @abstractmethod - def load_weights(self, weights: List[Dict]): + def load_weights(self, + weights: List[Dict], + allow_partial_loading: bool = False): raise NotImplementedError + def process_weights_after_loading(self): + pass + def post_load_weights(self): pass + def pre_reload_weights(self): + pass + @abstractmethod def quantize_input( self, diff --git a/tensorrt_llm/_torch/modules/fused_moe/quantization.py b/tensorrt_llm/_torch/modules/fused_moe/quantization.py index 18d00f75aef..f7ab7d10268 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/quantization.py +++ b/tensorrt_llm/_torch/modules/fused_moe/quantization.py @@ -18,7 +18,8 @@ from tensorrt_llm.quantization.utils.fp8_utils import ( resmooth_to_fp8_e8m0, transform_sf_into_required_layout) -from ...utils import swizzle_sf, unswizzle_sf +from ...utils import (replace_parameter_and_save_metadata, swizzle_sf, + unswizzle_sf) from ..linear import TensorParallelMode, load_weight_shard from .interface import MoEWeightLoadingMode @@ -236,6 +237,8 @@ def create_weights( module.w3_w1_bias = None module.w2_bias = None + module.rebuild_tensor_metadata = {} + def load_expert_weights_to_dst( self, module: torch.nn.Module, @@ -330,6 +333,12 @@ def load_weights(self, weights: List[Dict], weight_loading_mode: MoEWeightLoadingMode, allow_partial_loading: bool = False): + if allow_partial_loading: + assert isinstance( + self, (UnquantizedFusedMoEMethod, FP8QDQFusedMoEMethod, + DeepSeekFP8BlockScalesFusedMoEMethod, + DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm) + ), "Partial loading is only supported for unquantized and FP8 
models" additional_kargs = {} if "allow_partial_loading" in inspect.getfullargspec( self.load_expert_weights_to_dst).args: @@ -402,6 +411,9 @@ def load_weights(self, local_shared_w2_bias_tensors if module.bias else None, **additional_kargs) + if not allow_partial_loading: + self.process_weights_after_loading(module) + def post_load_weights(self, module: torch.nn.Module): if self.need_load_shared_weights(module): weight_fns = { @@ -432,6 +444,9 @@ def post_load_weights(self, module: torch.nn.Module): def load_quant_scales(self, module: torch.nn.Module, weights: List[Dict]): pass + def process_weights_after_loading(self, module: torch.nn.Module): + pass + @abstractmethod def setup_quant_scales(self, module: torch.nn.Module): raise NotImplementedError @@ -473,21 +488,19 @@ def load_expert_w3_w1_weight(self, TensorParallelMode.COLUMN, device=device) if w3_weight is not None else None - src_w3_size_shard = w3_weight_shard.shape[ - 0] if w3_weight_shard is not None else 0 - src_w1_size_shard = w1_weight_shard.shape[ - 0] if w1_weight_shard is not None else 0 - if w1_weight is not None: - dst_w1_weight = dst_w3_w1_weight.narrow(dim=0, - start=src_w3_size_shard, - length=src_w1_size_shard) - dst_w1_weight.copy_(w1_weight_shard.contiguous().view( - dst_w3_w1_weight.dtype), - non_blocking=True) - if w3_weight is not None: - dst_w3_weight = dst_w3_w1_weight.narrow(dim=0, - start=0, - length=src_w3_size_shard) + dst_w3_weight, dst_w1_weight = dst_w3_w1_weight.chunk(2, dim=0) + if w1_weight_shard is not None and w1_weight_shard.shape[0] != 0: + w1_weight_shard_viewed = w1_weight_shard.contiguous().view( + dst_w3_w1_weight.dtype) + if w1_weight_shard_viewed.shape[0] == dst_w3_w1_weight.shape[0]: + # w3_weight (gate_proj) should be empty for Nemotron-H MoE model. + dst_w3_w1_weight.copy_(w1_weight_shard_viewed, + non_blocking=True) + elif w1_weight_shard_viewed.shape[0] == dst_w1_weight.shape[0]: + dst_w1_weight.copy_(w1_weight_shard_viewed, non_blocking=True) + else: + raise ValueError("Shape mismatch!") + if w3_weight_shard is not None and w3_weight_shard.shape[0] != 0: dst_w3_weight.copy_(w3_weight_shard.contiguous().view( dst_w3_w1_weight.dtype), non_blocking=True) @@ -516,6 +529,16 @@ def load_expert_w2_weight(self, dst_w2_weight.copy_(w2_weight_shard.view(dst_w2_weight.dtype), non_blocking=True) + def pre_reload_weights(self, module: torch.nn.Module): + for param_name, metadata in module.rebuild_tensor_metadata.items(): + logger.warning( + f"Pre-reloading weight '{param_name}' requires tensor re-creation, which will invalidate existing CUDA graphs." 
+ ) + param = torch.nn.Parameter(torch.empty_like(metadata, + device="cuda"), + requires_grad=False) + module.register_parameter(param_name, param) + class UnquantizedFusedMoEMethod(FusedMoEMethodBase): @@ -539,8 +562,12 @@ def setup_quant_scales(self, module: torch.nn.Module): def load_expert_fc31_input_scale_fp8_qdq(w1_input_scale, w3_input_scale, dst_fc31_input_scale: torch.Tensor): - dst_fc31_input_scale.copy_( - max(w1_input_scale[...].reshape([]), w3_input_scale[...].reshape([]))) + if w1_input_scale is not None and w1_input_scale.numel() != 0: + w1_input_scale = w1_input_scale[...].reshape([]) + dst_fc31_input_scale[0].copy_(w1_input_scale) + if w3_input_scale is not None and w3_input_scale.numel() != 0: + w3_input_scale = w3_input_scale[...].reshape([]) + dst_fc31_input_scale[1].copy_(w3_input_scale) def load_expert_fc2_input_scale_fp8_qdq(w2_input_scale, @@ -549,35 +576,45 @@ def load_expert_fc2_input_scale_fp8_qdq(w2_input_scale, def load_activation_scales_fp8_qdq(module: torch.nn.Module, weights: Dict): - tmp_fc31_input_scale = torch.empty(module.num_experts, dtype=torch.float32) - tmp_fc2_input_scale = torch.empty(module.num_experts, dtype=torch.float32) + if not hasattr(module, 'tmp_fc31_input_scale'): + module.tmp_fc31_input_scale = torch.empty( + (module.num_experts, 2), + dtype=torch.float32, + device=module.fc31_dequant.device) + tmp_fc31_input_scale = module.tmp_fc31_input_scale + if not hasattr(module, 'tmp_fc2_input_scale'): + module.tmp_fc2_input_scale = torch.empty( + module.num_experts, + dtype=torch.float32, + device=module.fc2_dequant.device) + tmp_fc2_input_scale = module.tmp_fc2_input_scale for expert_id in range(module.num_experts): if module.weight_loading_mode == MoEWeightLoadingMode.VANILLA: - w1_input_scale = weights[f"{expert_id}.w1.input_scale"] - w3_input_scale = weights[f"{expert_id}.w3.input_scale"] - w2_input_scale = weights[f"{expert_id}.w2.input_scale"] + w1_input_scale = weights[ + f"{expert_id}.w1.input_scale"] if f"{expert_id}.w1.input_scale" in weights else None + w3_input_scale = weights[ + f"{expert_id}.w3.input_scale"] if f"{expert_id}.w3.input_scale" in weights else None + w2_input_scale = weights[ + f"{expert_id}.w2.input_scale"] if f"{expert_id}.w2.input_scale" in weights else None elif module.weight_loading_mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ: - w1_input_scale = weights[f"gate_up_proj_input_scale"] - w3_input_scale = weights[f"gate_up_proj_input_scale"] - w2_input_scale = weights[f"down_proj_input_scale"] + w1_input_scale = weights[ + f"gate_up_proj_input_scale"] if f"gate_up_proj_input_scale" in weights else None + w3_input_scale = weights[ + f"gate_up_proj_input_scale"] if f"gate_up_proj_input_scale" in weights else None + w2_input_scale = weights[ + f"down_proj_input_scale"] if f"down_proj_input_scale" in weights else None else: raise NotImplementedError( f"Unknown weight loading mode in MoE: {module.weight_loading_mode}" ) - load_expert_fc31_input_scale_fp8_qdq(w1_input_scale, w3_input_scale, - tmp_fc31_input_scale[expert_id]) - - load_expert_fc2_input_scale_fp8_qdq(w2_input_scale, - tmp_fc2_input_scale[expert_id]) - - # max_fc31_input_scale is the maximum of all w1 input scales and w3 input scales. - # It's used to quantize fc31 input inside the MOE op - max_fc31_input_scale = tmp_fc31_input_scale.max() - # max_fc2_input_scale is the maximum of all w2 input scales. 
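# Illustrative, standalone sketch (hypothetical module; not part of the patch): the
# pre_reload_weights hook added above re-creates any parameter that post-load processing
# replaced (padding, layout transforms), using the meta-tensor record kept by the
# replace_parameter_and_save_metadata helper this patch adds in tensorrt_llm/_torch/utils.py.
# Shown here on CPU for portability; the patch allocates the rebuilt parameter on "cuda".
import torch
from torch import nn


def replace_parameter_and_save_metadata(module, name, new_param, metadata):
    # Record the original shape/dtype once, as a storage-free meta tensor.
    if name not in metadata:
        metadata[name] = getattr(module, name).to("meta")
    module.register_parameter(name, new_param)


def pre_reload_weights(module):
    # Re-create the original-shaped parameters so the next full reload can copy into them.
    for name, meta in module.rebuild_tensor_metadata.items():
        module.register_parameter(
            name, nn.Parameter(torch.empty_like(meta, device="cpu"), requires_grad=False))


m = nn.Linear(8, 4, bias=False)
m.rebuild_tensor_metadata = {}
padded = nn.Parameter(torch.zeros(6, 8), requires_grad=False)  # e.g. row-aligned padding
replace_parameter_and_save_metadata(m, "weight", padded, m.rebuild_tensor_metadata)
pre_reload_weights(m)
assert m.weight.shape == (4, 8)  # original shape restored; data is uninitialized until reload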
- max_fc2_input_scale = tmp_fc2_input_scale.max() + if w1_input_scale is not None or w3_input_scale is not None: + load_expert_fc31_input_scale_fp8_qdq( + w1_input_scale, w3_input_scale, tmp_fc31_input_scale[expert_id]) - return max_fc31_input_scale, max_fc2_input_scale + if w2_input_scale is not None: + load_expert_fc2_input_scale_fp8_qdq(w2_input_scale, + tmp_fc2_input_scale[expert_id]) def requantize_expert_w3_w1_weight_fp8_qdq(module: torch.nn.Module, @@ -654,9 +691,12 @@ def setup_quant_scales(self, module: torch.nn.Module): def load_expert_w3_w1_weight_scale_fp8_qdq( self, w1_weight_scale, w3_weight_scale, dst_w3_w1_weight_scale: torch.Tensor): - w1_weight_scale = w1_weight_scale[...].reshape([]) - w3_weight_scale = w3_weight_scale[...].reshape([]) - dst_w3_w1_weight_scale.copy_(max(w1_weight_scale, w3_weight_scale)) + if w1_weight_scale is not None and w1_weight_scale.numel() != 0: + w1_weight_scale = w1_weight_scale[...].reshape([]) + dst_w3_w1_weight_scale[0].copy_(w1_weight_scale) + if w3_weight_scale is not None and w3_weight_scale.numel() != 0: + w3_weight_scale = w3_weight_scale[...].reshape([]) + dst_w3_w1_weight_scale[1].copy_(w3_weight_scale) def load_expert_w2_weight_scale_fp8(self, w2_weight_scale, dst_w2_weight_scale: torch.Tensor): @@ -664,25 +704,38 @@ def load_expert_w2_weight_scale_fp8(self, w2_weight_scale, def load_quant_scales(self, module: torch.nn.Module, weights: Dict): # Step1: Load input scales. - max_fc31_input_scale, max_fc2_input_scale = load_activation_scales_fp8_qdq( - module, weights) - - # Step2: Load weight scales and requantize w3_w1_weight. - tmp_w3_w1_weight_scale = torch.empty(module.expert_size_per_partition, - dtype=torch.float32) - tmp_w2_weight_scale = torch.empty(module.expert_size_per_partition, - dtype=torch.float32) + load_activation_scales_fp8_qdq(module, weights) + + # Step2: Load weight scales + if not hasattr(module, 'tmp_w3_w1_weight_scale'): + module.tmp_w3_w1_weight_scale = torch.empty( + (module.expert_size_per_partition, 2), + dtype=torch.float32, + device=module.fc31_dequant.device) + if not hasattr(module, 'tmp_w2_weight_scale'): + module.tmp_w2_weight_scale = torch.empty( + module.expert_size_per_partition, + dtype=torch.float32, + device=module.fc2_dequant.device) + tmp_w3_w1_weight_scale = module.tmp_w3_w1_weight_scale + tmp_w2_weight_scale = module.tmp_w2_weight_scale for local_slot_id, expert_id in enumerate( module.initial_local_expert_ids): if module.weight_loading_mode == MoEWeightLoadingMode.VANILLA: - w1_weight_scale = weights[f"{expert_id}.w1.weight_scale"] - w3_weight_scale = weights[f"{expert_id}.w3.weight_scale"] - w2_weight_scale = weights[f"{expert_id}.w2.weight_scale"] + w1_weight_scale = weights[ + f"{expert_id}.w1.weight_scale"] if f"{expert_id}.w1.weight_scale" in weights else None + w3_weight_scale = weights[ + f"{expert_id}.w3.weight_scale"] if f"{expert_id}.w3.weight_scale" in weights else None + w2_weight_scale = weights[ + f"{expert_id}.w2.weight_scale"] if f"{expert_id}.w2.weight_scale" in weights else None elif module.weight_loading_mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ: - w1_weight_scale = weights[f"gate_up_proj_weight_scale"] - w3_weight_scale = weights[f"gate_up_proj_weight_scale"] - w2_weight_scale = weights[f"down_proj_weight_scale"] + w1_weight_scale = weights[ + f"gate_up_proj_weight_scale"] if f"gate_up_proj_weight_scale" in weights else None + w3_weight_scale = weights[ + f"gate_up_proj_weight_scale"] if f"gate_up_proj_weight_scale" in weights else None + w2_weight_scale = weights[ + 
f"down_proj_weight_scale"] if f"down_proj_weight_scale" in weights else None else: raise NotImplementedError( f"Unknown weight loading mode in MoE: {module.weight_loading_mode}" @@ -690,24 +743,45 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict): expert_idx = local_slot_id - self.load_expert_w3_w1_weight_scale_fp8_qdq( - w1_weight_scale, w3_weight_scale, - tmp_w3_w1_weight_scale[expert_idx]) - + if w1_weight_scale is not None or w3_weight_scale is not None: + self.load_expert_w3_w1_weight_scale_fp8_qdq( + w1_weight_scale, w3_weight_scale, + tmp_w3_w1_weight_scale[expert_idx]) + + if w2_weight_scale is not None: + self.load_expert_w2_weight_scale_fp8( + w2_weight_scale, tmp_w2_weight_scale[expert_idx]) + + def process_weights_after_loading(self, module: torch.nn.Module): + # max_fc31_input_scale is the maximum of all w1 input scales and w3 input scales. + # It's used to quantize fc31 input inside the MOE op + max_fc31_input_scale = module.tmp_fc31_input_scale.max() + # max_fc2_input_scale is the maximum of all w2 input scales. + max_fc2_input_scale = module.tmp_fc2_input_scale.max() + # Requantize w3_w1_weight + for local_slot_id, _ in enumerate(module.initial_local_expert_ids): + expert_idx = local_slot_id requantize_expert_w3_w1_weight_fp8_qdq( - module, w1_weight_scale, w3_weight_scale, + module, module.tmp_w3_w1_weight_scale[expert_idx][0], + module.tmp_w3_w1_weight_scale[expert_idx][1], module.w3_w1_weight.data[expert_idx]) - self.load_expert_w2_weight_scale_fp8( - w2_weight_scale, tmp_w2_weight_scale[expert_idx]) - - # Step3: calculate and store final loaded weights - module.fc31_dequant.data.copy_(tmp_w3_w1_weight_scale * + # Calculate and store final loaded weights + max_w3_w1_weight_scale = module.tmp_w3_w1_weight_scale.max(dim=1).values + module.fc31_dequant.data.copy_(max_w3_w1_weight_scale * max_fc31_input_scale) module.fc2_quant.data.copy_(max_fc2_input_scale.reciprocal()) - module.fc2_dequant.data.copy_(tmp_w2_weight_scale * max_fc2_input_scale) + module.fc2_dequant.data.copy_(module.tmp_w2_weight_scale * + max_fc2_input_scale) module.fc31_input_dequant.data.copy_(max_fc31_input_scale) + self.setup_quant_scales(module) + + delattr(module, 'tmp_w3_w1_weight_scale') + delattr(module, 'tmp_w2_weight_scale') + delattr(module, 'tmp_fc31_input_scale') + delattr(module, 'tmp_fc2_input_scale') + def post_load_weights(self, module): super().post_load_weights(module) @@ -733,11 +807,15 @@ def _maybe_padding_weights(tensor: torch.Tensor, row_alignment: int, module.w2_weight, cutlass_fp8_row_alignment, cutlass_fp8_row_alignment) if is_padded_w3_w1_weight: - module.w3_w1_weight = nn.Parameter(padded_w3_w1_weight, - requires_grad=False) + replace_parameter_and_save_metadata( + module, "w3_w1_weight", + nn.Parameter(padded_w3_w1_weight, requires_grad=False), + module.rebuild_tensor_metadata) if is_padded_w2_weight: - module.w2_weight = nn.Parameter(padded_w2_weight, - requires_grad=False) + replace_parameter_and_save_metadata( + module, "w2_weight", + nn.Parameter(padded_w2_weight, requires_grad=False), + module.rebuild_tensor_metadata) class DeepSeekFP8BlockScalesFusedMoEMethod(FusedMoEMethodBase): @@ -778,9 +856,13 @@ def create_weights(self, module: torch.nn.Module): self.setup_quant_scales(module) - def load_weights(self, module: torch.nn.Module, weights: List[Dict], - weight_loading_mode: MoEWeightLoadingMode): - super().load_weights(module, weights, weight_loading_mode) + def load_weights(self, + module: torch.nn.Module, + weights: List[Dict], + 
weight_loading_mode: MoEWeightLoadingMode, + allow_partial_loading: bool = False): + super().load_weights(module, weights, weight_loading_mode, + allow_partial_loading) def setup_quant_scales(self, module: torch.nn.Module): module.quant_scales = FusedMoEQuantScalesDeepSeekFP8BlockScales( @@ -795,42 +877,53 @@ def load_expert_all_weight_scale_fp8_block_scale( for local_slot_id, expert_id in enumerate(load_expert_ids): if module.weight_loading_mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ: w3_scale = weights['gate_up_proj_weight_scale'][ - expert_id].transpose(0, 1).contiguous() - w1_scale = None + expert_id].transpose(0, 1).contiguous( + ) if "gate_up_proj_weight_scale" in weights else None w2_scale = weights['down_proj_weight_scale'][ - expert_id].transpose(0, 1).contiguous() + expert_id].transpose(0, 1).contiguous( + ) if "down_proj_weight_scale" in weights else None + w3_w1_scale_shard = load_weight_shard(w3_scale, + module.tp_size, + module.tp_rank, + TensorParallelMode.COLUMN, + device=device) + dst_w3_w1_weight_scale[local_slot_id].copy_(w3_w1_scale_shard) elif module.weight_loading_mode == MoEWeightLoadingMode.VANILLA: - w3_scale = weights[f"{expert_id}.w3.weight_scale_inv"] - w1_scale = weights[f"{expert_id}.w1.weight_scale_inv"] - w2_scale = weights[f"{expert_id}.w2.weight_scale_inv"] + w3_scale = weights[ + f"{expert_id}.w3.weight_scale_inv"] if f"{expert_id}.w3.weight_scale_inv" in weights else None + w1_scale = weights[ + f"{expert_id}.w1.weight_scale_inv"] if f"{expert_id}.w1.weight_scale_inv" in weights else None + w2_scale = weights[ + f"{expert_id}.w2.weight_scale_inv"] if f"{expert_id}.w2.weight_scale_inv" in weights else None + dst_w3_weight_scale, dst_w1_weight_scale = dst_w3_w1_weight_scale[ + local_slot_id].chunk(2, dim=0) + if w1_scale is not None: + w1_scale_shard = load_weight_shard( + w1_scale, + module.tp_size, + module.tp_rank, + TensorParallelMode.COLUMN, + device=device) + dst_w1_weight_scale.copy_(w1_scale_shard) + if w3_scale is not None: + w3_scale_shard = load_weight_shard( + w3_scale, + module.tp_size, + module.tp_rank, + TensorParallelMode.COLUMN, + device=device) + dst_w3_weight_scale.copy_(w3_scale_shard) else: raise NotImplementedError( f"Unknown weight loading mode in MoE: {module.weight_loading_mode}" ) - - w3_w1_scale_shard = load_weight_shard(w3_scale, - module.tp_size, - module.tp_rank, - TensorParallelMode.COLUMN, - device=device) - - if w1_scale is not None: - w1_scale_shard = load_weight_shard(w1_scale, + if w2_scale is not None: + w2_scale_shard = load_weight_shard(w2_scale, module.tp_size, module.tp_rank, - TensorParallelMode.COLUMN, + TensorParallelMode.ROW, device=device) - w3_w1_scale_shard = torch.cat( - [w3_w1_scale_shard, w1_scale_shard], dim=-2) - - dst_w3_w1_weight_scale[local_slot_id].copy_(w3_w1_scale_shard) - - w2_scale_shard = load_weight_shard(w2_scale, - module.tp_size, - module.tp_rank, - TensorParallelMode.ROW, - device=device) - dst_w2_weight_scale[local_slot_id].copy_(w2_scale_shard) + dst_w2_weight_scale[local_slot_id].copy_(w2_scale_shard) def load_quant_scales(self, module: torch.nn.Module, weights: Dict): self.load_expert_all_weight_scale_fp8_block_scale( @@ -843,16 +936,30 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict): if self.need_load_shared_weights(module): local_shared_load_expert_ids = module.layer_load_balancer.get_load_expert_ids( ) - local_shared_w3_w1_scale_tensors = torch.empty( - (len(local_shared_load_expert_ids), ) + - module.w3_w1_weight_scaling_factor.data.shape[1:], - 
dtype=module.w3_w1_weight_scaling_factor.data.dtype, - device='cpu') - local_shared_w2_scale_tensors = torch.empty( - (len(local_shared_load_expert_ids), ) + - module.w2_weight_scaling_factor.data.shape[1:], - dtype=module.w2_weight_scaling_factor.data.dtype, - device='cpu') + if getattr(module, 'local_shared_w3_w1_scale_tensors', + None) is not None: + local_shared_w3_w1_scale_tensors = getattr( + module, 'local_shared_w3_w1_scale_tensors') + else: + local_shared_w3_w1_scale_tensors = torch.empty( + (len(local_shared_load_expert_ids), ) + + module.w3_w1_weight_scaling_factor.data.shape[1:], + dtype=module.w3_w1_weight_scaling_factor.data.dtype, + device='cpu') + setattr(module, 'local_shared_w3_w1_scale_tensors', + local_shared_w3_w1_scale_tensors) + if getattr(module, 'local_shared_w2_scale_tensors', + None) is not None: + local_shared_w2_scale_tensors = getattr( + module, 'local_shared_w2_scale_tensors') + else: + local_shared_w2_scale_tensors = torch.empty( + (len(local_shared_load_expert_ids), ) + + module.w2_weight_scaling_factor.data.shape[1:], + dtype=module.w2_weight_scaling_factor.data.dtype, + device='cpu') + setattr(module, 'local_shared_w2_scale_tensors', + local_shared_w2_scale_tensors) self.load_expert_all_weight_scale_fp8_block_scale( module, weights, @@ -860,19 +967,32 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict): local_shared_w3_w1_scale_tensors, local_shared_w2_scale_tensors, device=torch.device("cpu")) - module.register_all_parameter_slot_and_to_fix_weight_fns({ - 'w3_w1_weight_scaling_factor': - local_shared_w3_w1_scale_tensors, - 'w2_weight_scaling_factor': - local_shared_w2_scale_tensors, - }) + + def post_load_weights(self, module: torch.nn.Module): + if self.need_load_shared_weights(module): + weight_fns = {} + if hasattr(module, 'local_shared_w3_w1_scale_tensors'): + weight_fns['w3_w1_weight_scaling_factor'] = getattr( + module, 'local_shared_w3_w1_scale_tensors') + delattr(module, 'local_shared_w3_w1_scale_tensors') + if hasattr(module, 'local_shared_w2_scale_tensors'): + weight_fns['w2_weight_scaling_factor'] = getattr( + module, 'local_shared_w2_scale_tensors') + delattr(module, 'local_shared_w2_scale_tensors') + if weight_fns: + module.register_all_parameter_slot_and_to_fix_weight_fns( + weight_fns) + super().post_load_weights(module) class DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm( DeepSeekFP8BlockScalesFusedMoEMethod): - def load_weights(self, module: torch.nn.Module, weights: List[Dict], - weight_loading_mode: MoEWeightLoadingMode): + def load_weights(self, + module: torch.nn.Module, + weights: List[Dict], + weight_loading_mode: MoEWeightLoadingMode, + allow_partial_loading: bool = False): if is_sm_100f(): expert_ids = set(module.initial_local_expert_ids) if self.need_load_shared_weights(module): @@ -888,7 +1008,8 @@ def load_weights(self, module: torch.nn.Module, weights: List[Dict], scale = weights[name][:] weights[weight_name], weights[name] = resmooth_to_fp8_e8m0( weight, scale) - super().load_weights(module, weights, weight_loading_mode) + super().load_weights(module, weights, weight_loading_mode, + allow_partial_loading) def post_load_weights(self, module: torch.nn.Module): super().post_load_weights(module) @@ -900,8 +1021,12 @@ def post_load_weights(self, module: torch.nn.Module): recipe=(1, 128, 128), num_groups=module.w3_w1_weight.shape[0], is_sfa=False) - module.w3_w1_weight_scaling_factor = nn.Parameter( + transformed_w3_w1_weight_scaling_factor = nn.Parameter( transfromed_w3_w1_scale, requires_grad=False) + 
replace_parameter_and_save_metadata( + module, "w3_w1_weight_scaling_factor", + transformed_w3_w1_weight_scaling_factor, + module.rebuild_tensor_metadata) transfromed_w2_scale = transform_sf_into_required_layout( module.quant_scales[1], mn=module.w2_weight.shape[1], @@ -909,8 +1034,12 @@ def post_load_weights(self, module: torch.nn.Module): recipe=(1, 128, 128), num_groups=module.w3_w1_weight.shape[0], is_sfa=False) - module.w2_weight_scaling_factor = nn.Parameter(transfromed_w2_scale, - requires_grad=False) + transformed_w2_weight_scaling_factor = nn.Parameter( + transfromed_w2_scale, requires_grad=False) + replace_parameter_and_save_metadata( + module, "w2_weight_scaling_factor", + transformed_w2_weight_scaling_factor, + module.rebuild_tensor_metadata) self.setup_quant_scales(module) diff --git a/tensorrt_llm/_torch/modules/linear.py b/tensorrt_llm/_torch/modules/linear.py index 44daa25eb3c..2f147985e36 100644 --- a/tensorrt_llm/_torch/modules/linear.py +++ b/tensorrt_llm/_torch/modules/linear.py @@ -28,7 +28,8 @@ from ..._utils import get_sm_version, is_sm_100f from ...models.modeling_utils import QuantConfig -from ..utils import Fp4QuantizedTensor, get_model_extra_attrs, unswizzle_sf +from ..utils import (Fp4QuantizedTensor, get_model_extra_attrs, + replace_parameter_and_save_metadata, unswizzle_sf) class WeightMode(str, enum.Enum): @@ -39,6 +40,40 @@ class WeightMode(str, enum.Enum): # weight of a fused gate and up linear layer FUSED_GATE_UP_LINEAR = 'fused_gate_up_linear' + @property + def int_value(self) -> int: + _INT_MAP = { + WeightMode.VANILLA: 1, + WeightMode.FUSED_GATE_UP_LINEAR: 2, + WeightMode.FUSED_QKV_LINEAR: 3, + } + return _INT_MAP[self] + + @property + def shard_keys(self) -> list[str] | None: + _SHARD_KEYS_MAP = { + WeightMode.VANILLA: None, + WeightMode.FUSED_GATE_UP_LINEAR: ['gate', 'up'], + WeightMode.FUSED_QKV_LINEAR: ['q', 'k', 'v'], + } + return _SHARD_KEYS_MAP[self] + + @property + def shard_key_to_index(self) -> dict[str, int] | None: + _SHARD_KEY_TO_INDEX_MAP = { + WeightMode.VANILLA: None, + WeightMode.FUSED_GATE_UP_LINEAR: { + 'gate': 0, + 'up': 1 + }, + WeightMode.FUSED_QKV_LINEAR: { + 'q': 0, + 'k': 1, + 'v': 2 + }, + } + return _SHARD_KEY_TO_INDEX_MAP[self] + @dataclass(kw_only=True) class WeightsLoadingConfig: @@ -327,6 +362,9 @@ def load_weights(self, else: raise ValueError(f'unsupported weight mode: {weight_mode}') + if not allow_partial_loading: + self.process_weights_after_loading(module) + def post_load_weights(self, module: Linear): pass @@ -367,6 +405,36 @@ def load_weights_fused_gate_up_linear( """ raise NotImplementedError + def process_weights_after_loading(self, module: Linear): + """ + Process quantization weights and scales after loading weights. + """ + weight_mode = module.weights_loading_config.weight_mode + if weight_mode == WeightMode.VANILLA: + self.process_weights_after_loading_vanilla(module) + elif weight_mode == WeightMode.FUSED_QKV_LINEAR: + self.process_weights_after_loading_fused_qkv_linear(module) + elif weight_mode == WeightMode.FUSED_GATE_UP_LINEAR: + self.process_weights_after_loading_fused_gate_up_linear(module) + else: + raise ValueError(f'unsupported weight mode: {weight_mode}') + + def process_weights_after_loading_vanilla(self, module: Linear): + """ + Process quantization weights and scales after loading weights for vanilla linear layer. + """ + + def process_weights_after_loading_fused_qkv_linear(self, module: Linear): + """ + Process quantization weights and scales after loading weights for fused QKV linear layer. 
+ """ + + def process_weights_after_loading_fused_gate_up_linear( + self, module: Linear): + """ + Process quantization weights and scales after loading weights for fused gate up linear layer. + """ + class UnquantizedLinearMethod(LinearMethodBase): @@ -382,6 +450,8 @@ def create_weights(self, module: Linear, in_features: int, else: module.register_parameter("bias", None) + module.rebuild_tensor_metadata = {} + def apply(self, module: Linear, input: torch.Tensor, bias: Optional[torch.Tensor]): if module.use_custom_cublas_mm: @@ -440,8 +510,17 @@ def load_weights_fused_gate_up_linear( copy_weight_shard(module.weight, weight, shard_offset, shard_size) + def pre_reload_weights(self, module: Linear): + for param_name, metadata in module.rebuild_tensor_metadata.items(): + logger.warning( + f"Pre-reloading weight '{param_name}' requires tensor re-creation, which will invalidate existing CUDA graphs." + ) + param = Parameter(torch.empty_like(metadata, device="cuda"), + requires_grad=False) + module.register_parameter(param_name, param) + -class FP8QDQLinearMethod(LinearMethodBase): +class FP8QDQLinearMethod(UnquantizedLinearMethod): def create_weights(self, module: Linear, in_features: int, out_features: int, bias: bool, dtype: torch.dtype): @@ -468,6 +547,8 @@ def create_weights(self, module: Linear, in_features: int, else: module.register_parameter("bias", None) + module.rebuild_tensor_metadata = {} + def apply(self, module: Linear, input: torch.Tensor, bias: Optional[torch.Tensor]): cur_input_scale = module.input_scale @@ -518,93 +599,225 @@ def load_kv_scales(self, weights: List[Dict]): v_scale.append(w["v_scale"][...].reshape([])) return k_scale, v_scale - def load_weight_scales(self, weights: List[Dict]): - input_scale, weight_scale = [], [] - for w in weights: - if "input_scale" in w: - input_scale.append(w["input_scale"][...].reshape([])) - if "weight_scale" in w: - weight_scale.append(w["weight_scale"][...].reshape([])) - return input_scale, weight_scale + def load_weight_scales(self, + weights: List[Dict], + shard_keys: list[str] = None): + input_scales, weight_scales = {}, {} + if shard_keys is None: + for w in weights: + if "input_scale" in w: + input_scales[None] = w["input_scale"][...].reshape([]) + if "weight_scale" in w: + weight_scales[None] = w["weight_scale"][...].reshape([]) + else: + for shard_key, w in zip(shard_keys, weights): + if "input_scale" in w: + input_scales[shard_key] = w["input_scale"][...].reshape([]) + if "weight_scale" in w: + weight_scales[shard_key] = w["weight_scale"][...].reshape( + []) + return input_scales, weight_scales - def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None: - load_weights_vanilla_helper(module, weights) + def load_weights_vanilla(self, + module: Linear, + weights: List[Dict], + allow_partial_loading: bool = False) -> None: + super().load_weights_vanilla( + module, weights, allow_partial_loading=allow_partial_loading) input_scale, weight_scale = self.load_weight_scales(weights) - if len(input_scale) != 0: - # Static quantization - copy_weight(module.input_scale, input_scale[0]) + if input_scale: + copy_weight(module.input_scale, input_scale[None]) module.inv_input_scale.data = 1.0 / module.input_scale - else: - # Dynamic quantization + setattr(module, "has_static_input_scale", True) + if weight_scale: + copy_weight(module.weight_scale, weight_scale[None]) + + def process_weights_after_loading_vanilla(self, module: Linear): + if not hasattr(module, "has_static_input_scale"): module.input_scale = None 
module.inv_input_scale = None - copy_weight(module.weight_scale, weight_scale[0]) + else: + delattr(module, "has_static_input_scale") - def load_weights_fused_qkv_linear(self, module: Linear, - weights: List[Dict]) -> None: - q_weight, k_weight, v_weight = load_weights_fused_qkv_helper( - module, weights) + def load_weights_fused_qkv_linear( + self, + module: Linear, + weights: List[Dict], + allow_partial_loading: bool = False) -> None: + """ + Load weights for fused QKV linear layer. - input_scale, weight_scale = self.load_weight_scales(weights) - if len(input_scale) != 0: - # Static quantization - copy_weight(module.input_scale, max(input_scale)) + In partial loading mode, only loads weights and scales to their designated positions. + The actual rescaling is deferred to process_weights_after_loading_fused_qkv_linear. + """ + # Parent class handles weight loading + super().load_weights_fused_qkv_linear( + module, weights, allow_partial_loading=allow_partial_loading) + weight_mode = module.weights_loading_config.weight_mode + if not hasattr(module, "tmp_input_scales"): + module.tmp_input_scales = torch.empty( + weight_mode.int_value, + dtype=torch.float32, + device=module.input_scale.device) + if not hasattr(module, "tmp_weight_scales"): + module.tmp_weight_scales = torch.empty( + weight_mode.int_value, + dtype=torch.float32, + device=module.weight_scale.device) + # Load input_scale and weight_scale to tmp_qkv_input_scales and tmp_qkv_weight_scales + # q -> index 0, k -> index 1, v -> index 2 + input_scales, weight_scales = self.load_weight_scales( + weights, shard_keys=weight_mode.shard_keys) + shard_key_to_index = weight_mode.shard_key_to_index + + for shard_key, scale in input_scales.items(): + idx = shard_key_to_index[shard_key] + module.tmp_input_scales[idx] = scale + setattr(module, "has_static_input_scale", True) + + for shard_key, scale in weight_scales.items(): + idx = shard_key_to_index[shard_key] + module.tmp_weight_scales[idx] = scale + + # Load k and v scales, used for NVFP4 KV cache + # Store them temporarily for post-processing + k_scale, v_scale = self.load_kv_scales(weights) + if k_scale: + if getattr(module, "tmp_k_scales", None) is None: + module.tmp_k_scales = [] + module.tmp_k_scales.extend(k_scale) + if v_scale: + if getattr(module, "tmp_v_scales", None) is None: + module.tmp_v_scales = [] + module.tmp_v_scales.extend(v_scale) + + def rescale_fused_weights(self, module: Linear): + """ + Helper function to rescale fused weights. + + This method: + 1. Computes the max input_scale of all shards(qkv or gate/up) and update input_scale parameter to the max value + 2. Computes the max weight_scale across all shards(qkv or gate/up) + 3. Rescales each weight shard: weight * (original_scale / max_scale) + 4. 
Updates weight_scale parameter to the unified max value + """ + weight_mode = module.weights_loading_config.weight_mode + shard_key_to_index = weight_mode.shard_key_to_index + + # Handle input_scale + if hasattr(module, "has_static_input_scale"): + # Compute max and replace input_scale with a new parameter + max_input_scale = module.tmp_input_scales.max() + module.input_scale.data.copy_(max_input_scale) + module.inv_input_scale.data = 1.0 / module.input_scale + delattr(module, "has_static_input_scale") else: - # Dynamic quantization module.input_scale = None + module.inv_input_scale = None - copy_weight(module.weight_scale, max(weight_scale)) + # Compute max weight_scale + max_weight_scale = module.tmp_weight_scales.max() + module.weight_scale.data.copy_(max_weight_scale) - # use in-place multiplication and division to avoid extra memory allocation - q_weight = q_weight.to(module.dtype).mul_(weight_scale[0]) - k_weight = k_weight.to(module.dtype).mul_(weight_scale[1]) - v_weight = v_weight.to(module.dtype).mul_(weight_scale[2]) + # Rescale each weight shard: weight * (original_scale / max_scale) + for shard_key in weight_mode.shard_keys: + idx = shard_key_to_index[shard_key] + original_scale = module.tmp_weight_scales[idx] - fused_weight = torch.cat((q_weight, k_weight, v_weight)) - fused_weight = fused_weight.div_( - module.weight_scale.to(fused_weight.device)).to(torch.float8_e4m3fn) - copy_weight(module.weight, fused_weight) + # Get shard position from mapping + shard_offset, shard_size = module.fused_weight_shard_indices_mapping[ + shard_key] - # Load k and v scales, used for NVFP4 KV cache - k_scale, v_scale = self.load_kv_scales(weights) - # NOTE: Currently the calibrated kv scales may cause overflow for certain input, disabling by default. + # Rescale: FP8 -> BF16 -> multiply by ratio -> FP8 + weight_shard = module.weight.data[shard_offset:shard_offset + + shard_size] + rescaled_weight = (weight_shard.to(module.dtype).mul_( + original_scale / max_weight_scale).to(torch.float8_e4m3fn)) + module.weight.data[shard_offset:shard_offset + + shard_size] = rescaled_weight + + delattr(module, "tmp_input_scales") + delattr(module, "tmp_weight_scales") + + def process_weights_after_loading_fused_qkv_linear(self, module: Linear): + """ + Post-process weights after all partial loads are complete. + """ + self.rescale_fused_weights(module) + + # Handle kv_scales for NVFP4 KV cache if os.environ.get("TRTLLM_LOAD_KV_SCALES", "0") == "1": - if len(k_scale) != 0: - assert len(v_scale) != 0 + k_scales = getattr(module, "tmp_k_scales", []) + v_scales = getattr(module, "tmp_v_scales", []) + if k_scales: + assert v_scales, "k_scale and v_scale must be loaded together" # The calibrated KV scales are amax / (6 * 448), but the requested KV scales are amax / 448, # to avoid overflow when dequantizing NVFP4 in attention kernels. 
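# Illustrative numeric sketch (float32 stand-ins for FP8 storage; not part of the patch):
# rescale_fused_weights above folds per-shard weight scales into one fused scale. Each
# shard keeps representing scale_i * w_i by storing w_i * (scale_i / max_scale) against
# the shared max_scale, so dequantized values are unchanged after fusion.
import torch

scales = torch.tensor([0.5, 2.0, 1.0])           # per-shard weight_scale (e.g. q, k, v)
shards = [torch.randn(4, 8) for _ in scales]     # quantized shard values (stand-in for FP8)
dequant_before = [s * w for s, w in zip(scales, shards)]

max_scale = scales.max()
rescaled = [w * (s / max_scale) for s, w in zip(scales, shards)]
fused = torch.cat(rescaled)                      # single fused weight with a single scale

dequant_after = max_scale * fused
assert torch.allclose(torch.cat(dequant_before), dequant_after, atol=1e-6)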
copy_weight( module.kv_scales, - torch.tensor( - [1.0, max(k_scale) * 6.0, - max(v_scale) * 6.0], - dtype=torch.float32)) + torch.tensor([ + 1.0, + max(k_scales).item() * 6.0, + max(v_scales).item() * 6.0 + ], + dtype=torch.float32)) module.inv_kv_scales.data = 1.0 / module.kv_scales - def load_weights_fused_gate_up_linear(self, module: Linear, - weights: List[Dict]) -> None: - input_scale, weight_scale = self.load_weight_scales(weights) - if len(input_scale) != 0: - # Static quantization - copy_weight(module.input_scale, max(input_scale)) - else: - # Dynamic quantization - module.input_scale = None - copy_weight(module.weight_scale, max(weight_scale)) + # Clean up temporary attributes + if hasattr(module, "tmp_k_scales"): + delattr(module, "tmp_k_scales") + if hasattr(module, "tmp_v_scales"): + delattr(module, "tmp_v_scales") - gate_weight, up_weight = load_weights_fused_gate_up_helper( - module, weights) + def load_weights_fused_gate_up_linear( + self, + module: Linear, + weights: List[Dict], + allow_partial_loading: bool = False) -> None: + """ + Load weights for fused gate/up linear layer. - # use in-place multiplication and division to avoid extra memory allocation - gate_weight = gate_weight.to(module.dtype).mul_(weight_scale[0]) - up_weight = up_weight.to(module.dtype).mul_(weight_scale[1]) - fused_weight = torch.cat((gate_weight, up_weight)) - fused_weight = fused_weight.div_( - module.weight_scale.to(fused_weight.device)).to(torch.float8_e4m3fn) - copy_weight(module.weight, fused_weight) + In partial loading mode, only loads weights and scales to their designated positions. + The actual rescaling is deferred to process_weights_after_loading_fused_gate_up_linear. + """ + # Parent class handles weight loading + super().load_weights_fused_gate_up_linear( + module, weights, allow_partial_loading=allow_partial_loading) + weight_mode = module.weights_loading_config.weight_mode + if not hasattr(module, "tmp_input_scales"): + module.tmp_input_scales = torch.empty( + weight_mode.int_value, + dtype=torch.float32, + device=module.input_scale.device) + if not hasattr(module, "tmp_weight_scales"): + module.tmp_weight_scales = torch.empty( + weight_mode.int_value, + dtype=torch.float32, + device=module.weight_scale.device) + # Load input_scale and weight_scale to their designated positions + # gate -> index 0, up -> index 1 + input_scales, weight_scales = self.load_weight_scales( + weights, shard_keys=weight_mode.shard_keys) + shard_key_to_index = weight_mode.shard_key_to_index + + for shard_key, scale in input_scales.items(): + idx = shard_key_to_index[shard_key] + module.tmp_input_scales[idx] = scale + setattr(module, "has_static_input_scale", True) + + for shard_key, scale in weight_scales.items(): + idx = shard_key_to_index[shard_key] + module.tmp_weight_scales[idx] = scale + + def process_weights_after_loading_fused_gate_up_linear( + self, module: Linear): + """ + Post-process weights after all partial loads are complete. + """ + self.rescale_fused_weights(module) -class FP8RowwiseLinearMethod(LinearMethodBase): +class FP8RowwiseLinearMethod(UnquantizedLinearMethod): def create_weights(self, module: Linear, in_features: int, out_features: int, bias: bool, dtype: torch.dtype): @@ -628,6 +841,8 @@ def create_weights(self, module: Linear, in_features: int, else: module.register_parameter("bias", None) + module.rebuild_tensor_metadata = {} + def apply(self, module: Linear, input: torch.Tensor, bias: Optional[torch.Tensor]): # FP8 tensor inputs are from attention. Directly use ones as scale. 
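# Illustrative, standalone sketch (hypothetical helper and offsets; not part of the
# patch): the rowwise and block-scale hunks below copy each per-shard scale into its
# slice of the fused scale tensor via the (offset, size) entries of
# fused_weight_shard_indices_mapping. Shards that are absent from a partial load are
# simply skipped, which is what makes incremental reloading possible. The block-scale
# method additionally divides these offsets/sizes by the 128 block size (with ceil)
# before indexing into its scale tensor.
import torch


def copy_weight_shard(dst, src, shard_offset, shard_size):
    dst.data[shard_offset:shard_offset + shard_size].copy_(src)


fused_scale = torch.zeros(8)
mapping = {"q": (0, 4), "k": (4, 2), "v": (6, 2)}                  # hypothetical offsets/sizes
loaded = {"q": torch.full((4,), 0.5), "v": torch.full((2,), 2.0)}  # "k" arrives in a later call

for key, scale in loaded.items():
    offset, size = mapping[key]
    copy_weight_shard(fused_scale, scale, offset, size)

assert fused_scale[4:6].eq(0).all()  # untouched "k" slice awaits a later partial load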
@@ -661,51 +876,72 @@ def _get_scale_name(self, weights: List[Dict]): scale_name = "weight_scale" return scale_name - def load_weights_vanilla(self, module: Linear, weights: List[Dict]): - load_weights_vanilla_helper(module, weights) - + def load_weights_vanilla(self, + module: Linear, + weights: List[Dict], + allow_partial_loading: bool = False): + super().load_weights_vanilla( + module, weights, allow_partial_loading=allow_partial_loading) scale_name = self._get_scale_name(weights) - weight_scale = load_weight_shard(weights[0][scale_name], module.tp_size, - module.tp_rank, module.tp_mode) - copy_weight(module.weight_scale, weight_scale) + if scale_name in weights[0]: + weight_scale = load_weight_shard(weights[0][scale_name], + module.tp_size, module.tp_rank, + module.tp_mode) + copy_weight(module.weight_scale, weight_scale) if "input_scale" in weights[0]: copy_weight(module.input_scale, weights[0]["input_scale"]) module.inv_input_scale.data = 1.0 / module.input_scale - def load_weights_fused_qkv_linear(self, module: Linear, - weights: List[Dict]): - q_weight, k_weight, v_weight = load_weights_fused_qkv_helper( - module, weights) - fused_weight = torch.cat((q_weight, k_weight, v_weight)) - copy_weight(module.weight, fused_weight) - + def load_weights_fused_qkv_linear(self, + module: Linear, + weights: List[Dict], + allow_partial_loading: bool = False): + super().load_weights_fused_qkv_linear( + module, weights, allow_partial_loading=allow_partial_loading) scale_name = self._get_scale_name(weights) - q_scale = load_weight_shard(weights[0][scale_name], module.tp_size, - module.tp_rank, module.tp_mode) - k_scale = load_weight_shard(weights[1][scale_name], module.tp_size, - module.tp_rank, module.tp_mode) - v_scale = load_weight_shard(weights[2][scale_name], module.tp_size, - module.tp_rank, module.tp_mode) - fused_fp8_block_scale = torch.cat((q_scale, k_scale, v_scale)) - copy_weight(module.weight_scale, fused_fp8_block_scale) - - def load_weights_fused_gate_up_linear(self, module: Linear, - weights: List[Dict]): - gate_weight, up_weight = load_weights_fused_gate_up_helper( - module, weights) - fused_weight = torch.cat((gate_weight, up_weight)) - copy_weight(module.weight, fused_weight) + q_scale = load_weight_shard( + weights[0][scale_name], module.tp_size, module.tp_rank, + module.tp_mode) if scale_name in weights[0] else None + k_scale = load_weight_shard( + weights[1][scale_name], module.tp_size, module.tp_rank, + module.tp_mode) if scale_name in weights[1] else None + v_scale = load_weight_shard( + weights[2][scale_name], module.tp_size, module.tp_rank, + module.tp_mode) if scale_name in weights[2] else None + for shard_key, scale in zip( + module.fused_weight_shard_indices_mapping.keys(), + [q_scale, k_scale, v_scale]): + if scale is not None: + shard_offset, shard_size = module.fused_weight_shard_indices_mapping[ + shard_key] + copy_weight_shard(module.weight_scale, scale, shard_offset, + shard_size) + def load_weights_fused_gate_up_linear( + self, + module: Linear, + weights: List[Dict], + allow_partial_loading: bool = False) -> None: + super().load_weights_fused_gate_up_linear( + module, weights, allow_partial_loading=allow_partial_loading) scale_name = self._get_scale_name(weights) - left_scale = load_weight_shard(weights[0][scale_name], module.tp_size, - module.tp_rank, module.tp_mode) - right_scale = load_weight_shard(weights[1][scale_name], module.tp_size, - module.tp_rank, module.tp_mode) - fused_scale = torch.cat((left_scale, right_scale)) - copy_weight(module.weight_scale, 
fused_scale) - - -class FP8BlockScalesLinearMethod(LinearMethodBase): + gate_scale = load_weight_shard( + weights[0][scale_name], module.tp_size, module.tp_rank, + module.tp_mode) if scale_name in weights[0] else None + up_scale = load_weight_shard( + weights[1][scale_name], module.tp_size, module.tp_rank, + module.tp_mode) if scale_name in weights[1] else None + for shard_key, scale in zip( + module.fused_weight_shard_indices_mapping.keys(), + [gate_scale, up_scale]): + if scale is not None: + shard_offset, shard_size = module.fused_weight_shard_indices_mapping[ + shard_key] + copy_weight_shard(module.weight_scale, scale, shard_offset, + shard_size) + + +class FP8BlockScalesLinearMethod(UnquantizedLinearMethod): def create_weights(self, module: Linear, in_features: int, out_features: int, bias: bool, dtype: torch.dtype): @@ -732,6 +968,8 @@ def create_weights(self, module: Linear, in_features: int, else: module.register_parameter("bias", None) + module.rebuild_tensor_metadata = {} + def apply(self, module: Linear, input: torch.Tensor, bias: Optional[torch.Tensor]): if input.dtype == torch.float8_e4m3fn: @@ -770,75 +1008,107 @@ def apply(self, module: Linear, input: torch.Tensor, def _get_scale_name(self, weights: List[Dict]): # `weight_scale_inv` for DS recipe and `weight_scale` for ModelOpt recipe. # Actually they hold identical values of data_amax / 448. - scale_name = "weight_scale_inv" - if scale_name not in weights[0]: - scale_name = "weight_scale" - return scale_name + for w in weights: + if "weight_scale_inv" in w: + return "weight_scale_inv" + return "weight_scale" - def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None: - load_weights_vanilla_helper(module, weights) + def load_weights_vanilla(self, + module: Linear, + weights: List[Dict], + allow_partial_loading: bool = False) -> None: + super().load_weights_vanilla( + module, weights, allow_partial_loading=allow_partial_loading) scale_name = self._get_scale_name(weights) - full_weight_scale = weights[0][scale_name] - # modelopt fp8_pb_wo can have 2 extra singleton dimensions - if full_weight_scale.dim() == 4: - full_weight_scale = full_weight_scale.squeeze(1).squeeze(-1) - weight_scale = load_weight_shard(full_weight_scale, module.tp_size, - module.tp_rank, module.tp_mode) - copy_weight(module.weight_scale, weight_scale) + if scale_name in weights[0]: + full_weight_scale = weights[0][scale_name] + # modelopt fp8_pb_wo can have 2 extra singleton dimensions + if full_weight_scale.dim() == 4: + full_weight_scale = full_weight_scale.squeeze(1).squeeze(-1) + weight_scale = load_weight_shard(full_weight_scale, module.tp_size, + module.tp_rank, module.tp_mode) + copy_weight(module.weight_scale, weight_scale) if "input_scale" in weights[0]: copy_weight(module.input_scale, weights[0]["input_scale"]) module.inv_input_scale.data = 1.0 / module.input_scale - def load_weights_fused_qkv_linear(self, module: Linear, - weights: List[Dict]) -> None: - q_weight, k_weight, v_weight = load_weights_fused_qkv_helper( - module, weights) - fused_weight = torch.cat((q_weight, k_weight, v_weight)) + def remap_fused_shard_indices_by_divisible_factor(self, mapping: Dict, + divisible_factor: int): + """ + Remap fused weight shard indices to scale coordinates by dividing by divisible_factor. 
+ + Args: + mapping: Dict of {shard_key: (offset, size)} in weight coordinates + divisible_factor: Block size (e.g., 128 for block-scale quantization) + + Returns: + Dict of {shard_key: (scale_offset, scale_size)} in scale coordinates + """ + result = {} + for key, (offset, size) in mapping.items(): + scale_offset = math.ceil(offset / divisible_factor) + scale_size = math.ceil(size / divisible_factor) + result[key] = (scale_offset, scale_size) + return result + + def load_weights_fused_qkv_linear( + self, + module: Linear, + weights: List[Dict], + allow_partial_loading: bool = False) -> None: + super().load_weights_fused_qkv_linear( + module, weights, allow_partial_loading=allow_partial_loading) scale_name = self._get_scale_name(weights) - full_q_scale = weights[0][scale_name] - full_k_scale = weights[1][scale_name] - full_v_scale = weights[2][scale_name] # modelopt fp8_pb_wo can have 2 extra singleton dimensions - if full_q_scale.dim() == 4: - full_q_scale = full_q_scale.squeeze(1).squeeze(-1) - if full_k_scale.dim() == 4: - full_k_scale = full_k_scale.squeeze(1).squeeze(-1) - if full_v_scale.dim() == 4: - full_v_scale = full_v_scale.squeeze(1).squeeze(-1) - q_scale = load_weight_shard(full_q_scale, module.tp_size, - module.tp_rank, module.tp_mode) - k_scale = load_weight_shard(full_k_scale, module.tp_size, - module.tp_rank, module.tp_mode) - v_scale = load_weight_shard(full_v_scale, module.tp_size, - module.tp_rank, module.tp_mode) - fused_fp8_block_scale = torch.cat((q_scale, k_scale, v_scale)) + full_scales = [ + w[scale_name] if scale_name in w else None for w in weights[:3] + ] + full_scales_squeezed = [ + s.squeeze(1).squeeze(-1) if s is not None and s.dim() == 4 else s + for s in full_scales + ] - copy_weight(module.weight, fused_weight) - copy_weight(module.weight_scale, fused_fp8_block_scale) + scales = [ + load_weight_shard(s, module.tp_size, module.tp_rank, module.tp_mode) + if s is not None else None for s in full_scales_squeezed + ] + processed_mapping = self.remap_fused_shard_indices_by_divisible_factor( + module.fused_weight_shard_indices_mapping, 128) + for shard_key, scale in zip(processed_mapping.keys(), scales): + if scale is not None: + shard_offset, shard_size = processed_mapping[shard_key] + copy_weight_shard(module.weight_scale, scale, shard_offset, + shard_size) - def load_weights_fused_gate_up_linear(self, module: Linear, - weights: List[Dict]) -> None: - gate_weight, up_weight = load_weights_fused_gate_up_helper( - module, weights) - fused_weight = torch.cat((gate_weight, up_weight)) + def load_weights_fused_gate_up_linear( + self, + module: Linear, + weights: List[Dict], + allow_partial_loading: bool = False) -> None: + super().load_weights_fused_gate_up_linear( + module, weights, allow_partial_loading=allow_partial_loading) scale_name = self._get_scale_name(weights) - full_left_scale = weights[0][scale_name] - full_right_scale = weights[1][scale_name] - # modelopt fp8_pb_wo can have 2 extra singleton dimensions - if full_left_scale.dim() == 4: - full_left_scale = full_left_scale.squeeze(1).squeeze(-1) - if full_right_scale.dim() == 4: - full_right_scale = full_right_scale.squeeze(1).squeeze(-1) - left_scale = load_weight_shard(full_left_scale, module.tp_size, - module.tp_rank, module.tp_mode) - right_scale = load_weight_shard(full_right_scale, module.tp_size, - module.tp_rank, module.tp_mode) - fused_scale = torch.cat([left_scale, right_scale], dim=0) - copy_weight(module.weight, fused_weight) - copy_weight(module.weight_scale, fused_scale) + full_scales = [ + 
w[scale_name] if scale_name in w else None for w in weights[:2] + ] + full_scales_squeezed = [ + s.squeeze(1).squeeze(-1) if s is not None and s.dim() == 4 else s + for s in full_scales + ] + scales = [ + load_weight_shard(s, module.tp_size, module.tp_rank, module.tp_mode) + if s is not None else None for s in full_scales_squeezed + ] + processed_mapping = self.remap_fused_shard_indices_by_divisible_factor( + module.fused_weight_shard_indices_mapping, 128) + for shard_key, scale in zip(processed_mapping.keys(), scales): + if scale is not None: + shard_offset, shard_size = processed_mapping[shard_key] + copy_weight_shard(module.weight_scale, scale, shard_offset, + shard_size) def post_load_weights(self, module: Linear): super().post_load_weights(module) @@ -847,17 +1117,19 @@ def post_load_weights(self, module: Linear): get_sm_version() == 120: weight, weight_scale = resmooth_to_fp8_e8m0(module.weight, module.weight_scale) - transfromed_scale = transform_sf_into_required_layout( + transformed_scale = transform_sf_into_required_layout( weight_scale, mn=weight.shape[0], k=weight.shape[1], recipe=(1, 128, 128), is_sfa=False) - module.weight = nn.Parameter(weight, requires_grad=False) - module.weight_scale = nn.Parameter( - transfromed_scale, - requires_grad=False, - ) + replace_parameter_and_save_metadata( + module, "weight", nn.Parameter(weight, requires_grad=False), + module.rebuild_tensor_metadata) + replace_parameter_and_save_metadata( + module, "weight_scale", + nn.Parameter(transformed_scale, requires_grad=False), + module.rebuild_tensor_metadata) class NVFP4LinearMethod(LinearMethodBase): @@ -2293,5 +2565,14 @@ def load_weights(self, weight_mode, allow_partial_loading=allow_partial_loading) + def process_weights_after_loading(self): + self.quant_method.process_weights_after_loading(self) + def post_load_weights(self): self.quant_method.post_load_weights(self) + + def pre_reload_weights(self): + assert hasattr( + self.quant_method, "pre_reload_weights" + ), "pre_reload_weights is not supported for this quant method" + self.quant_method.pre_reload_weights(self) diff --git a/tensorrt_llm/_torch/utils.py b/tensorrt_llm/_torch/utils.py index 55276832ff7..1f3b1cc9e92 100644 --- a/tensorrt_llm/_torch/utils.py +++ b/tensorrt_llm/_torch/utils.py @@ -414,3 +414,15 @@ def maybe_compiled_copy_(dst, src): @maybe_compile def maybe_compiled_cat(tensors, dim): return torch.cat(tensors, dim) + + +def replace_parameter_and_save_metadata(module: torch.nn.Module, + param_name: str, + new_param: torch.nn.Parameter, + metadata_dict: Dict): + """ + Replace a parameter in a module and save the metadata of the original parameter. + """ + if param_name not in metadata_dict: + metadata_dict[param_name] = getattr(module, param_name).to("meta") + module.register_parameter(param_name, new_param) diff --git a/tensorrt_llm/llmapi/rlhf_utils.py b/tensorrt_llm/llmapi/rlhf_utils.py index ce6eaa5b4ff..453e4d4b972 100644 --- a/tensorrt_llm/llmapi/rlhf_utils.py +++ b/tensorrt_llm/llmapi/rlhf_utils.py @@ -50,6 +50,13 @@ def update_weights(self, ipc_handles: Optional[dict] = None): Exception: Re-raises any exception encountered during weight update. 
""" try: + if not hasattr(self.engine.model_engine.model, "first_pre_reload_weights"): + for module in self.engine.model_engine.model.modules(): + if hasattr(module, "pre_reload_weights") and not getattr( + module, "_weights_removed", False + ): + module.pre_reload_weights() + setattr(self.engine.model_engine.model, "first_pre_reload_weights", True) if ipc_handles is not None: logger.info("Update weights from IPC handles") device_uuid = get_device_uuid(self.device_id) @@ -82,6 +89,10 @@ def update_weights(self, ipc_handles: Optional[dict] = None): else: logger.info("Finalize update weights") for module in self.engine.model_engine.model.modules(): + if hasattr(module, "process_weights_after_loading") and not getattr( + module, "_weights_removed", False + ): + module.process_weights_after_loading() if hasattr(module, "post_load_weights") and not getattr( module, "_weights_removed", False ): @@ -93,6 +104,7 @@ def update_weights(self, ipc_handles: Optional[dict] = None): moe_load_balancer.finalize_model() logger.info("moe_load_balancer finalize model done") self.engine.reset_prefix_cache() + delattr(self.engine.model_engine.model, "first_pre_reload_weights") except Exception as e: logger.error("Encountered an error in update_weights") diff --git a/tests/unittest/_torch/ray_orchestrator/multi_gpu/test_accuracy_with_allreduce_strategy.py b/tests/unittest/_torch/ray_orchestrator/multi_gpu/test_accuracy_with_allreduce_strategy.py index 765bd7f5f40..e1cb8bb1ce5 100644 --- a/tests/unittest/_torch/ray_orchestrator/multi_gpu/test_accuracy_with_allreduce_strategy.py +++ b/tests/unittest/_torch/ray_orchestrator/multi_gpu/test_accuracy_with_allreduce_strategy.py @@ -15,112 +15,21 @@ import asyncio import os from functools import partial -from typing import List, Tuple +from typing import List import pytest import ray import torch -from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers import AutoTokenizer from utils.llm_data import llm_models_root +from utils.torch_ref import RefHFModel from tensorrt_llm import LLM from tensorrt_llm.llmapi import KvCacheConfig, SamplingParams -class HFModel: - def __init__(self, model_name: str, device_id: int): - self.device_id = device_id - self.model = AutoModelForCausalLM.from_pretrained( - model_name, torch_dtype=torch.bfloat16 - ).to(f"cuda:{device_id}") - - def generate_batch_with_padding( - self, - input_ids: torch.Tensor, - attention_mask: torch.Tensor, - position_ids: torch.Tensor, - responses: List[List[int]], - prompt_max_len: int = 1024, - micro_batch_size: int = 16, - ): - """ - Synchronous inference on a batch with micro-batching. - Directly extracts response logprobs to save memory. 
- - Args: - input_ids: [batch_size, seq_len] - attention_mask: [batch_size, seq_len] - position_ids: [batch_size, seq_len] - responses: List of response token IDs for each sample - prompt_max_len: Maximum prompt length (default 1024) - micro_batch_size: Size of each micro batch to avoid OOM - - Returns: - List of logprobs tensors, one per sample [response_len] - """ - # Move tensors to the correct device - input_ids = input_ids.to(f"cuda:{self.device_id}") - attention_mask = attention_mask.to(f"cuda:{self.device_id}") - position_ids = position_ids.to(f"cuda:{self.device_id}") - - batch_size = input_ids.shape[0] - num_micro_batches = (batch_size + micro_batch_size - 1) // micro_batch_size - - all_response_logprobs = [] - - with torch.no_grad(): - for micro_idx in range(num_micro_batches): - start_idx = micro_idx * micro_batch_size - end_idx = min((micro_idx + 1) * micro_batch_size, batch_size) - - # Extract micro batch - micro_input_ids = input_ids[start_idx:end_idx] - micro_attention_mask = attention_mask[start_idx:end_idx] - micro_position_ids = position_ids[start_idx:end_idx] - - # Forward pass - outputs = self.model( - input_ids=micro_input_ids, - attention_mask=micro_attention_mask, - position_ids=micro_position_ids, - ) - - # Extract response logprobs for each sample in this micro batch - micro_logits = outputs.logits # [micro_batch_size, seq_len, vocab_size] - - for i in range(micro_logits.shape[0]): - sample_idx = start_idx + i - response = responses[sample_idx] - response_len = len(response) - - # Extract logits for predicting response tokens - # For predicting response[j], we need logits at position prompt_max_len-1+j - response_logits = micro_logits[ - i, prompt_max_len - 1 : prompt_max_len - 1 + response_len, : - ] - - # Convert to logprobs - response_logprobs = torch.log_softmax(response_logits, dim=-1) - - # Extract logprobs for the actual generated tokens - response_tensor = torch.tensor( - response, dtype=torch.long, device=response_logprobs.device - ) - ref_logprob_for_tokens = torch.gather( - response_logprobs, dim=-1, index=response_tensor.unsqueeze(-1) - ).squeeze(-1) - - all_response_logprobs.append(ref_logprob_for_tokens) - - # Free memory immediately after processing each micro batch - del outputs, micro_logits - torch.cuda.empty_cache() - - return all_response_logprobs - - async def generate_batch_async( - hf_model: HFModel, + hf_model: RefHFModel, input_ids: torch.Tensor, attention_mask: torch.Tensor, position_ids: torch.Tensor, @@ -133,7 +42,7 @@ async def generate_batch_async( Runs the synchronous model inference in a thread pool. Args: - hf_model: HFModel instance + hf_model: RefHFModel instance input_ids: Input token IDs attention_mask: Attention mask position_ids: Position IDs @@ -163,89 +72,6 @@ async def generate_batch_async( return result -def pad_data( - original_prompts: List[List[int]], - generated_token_ids_list: List[List[int]], - prompt_max_len: int = 1024, - response_max_len: int = 1024, - pad_token_id: int = 0, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Pad the data to the maximum length. 
- - Structure: - [left_pad | actual_prompt | actual_response | right_pad] - |<-- prompt_max_len=1024 -->|<-- response_max_len=1024 -->| - - Args: - original_prompts: List of prompt token IDs, len = batch_size - generated_token_ids_list: List of response token IDs, len = batch_size - prompt_max_len: Maximum length for prompt section (default 1024) - response_max_len: Maximum length for response section (default 1024) - pad_token_id: Token ID for padding (default 0) - Returns: - input_ids: Tensor of shape [batch_size, prompt_max_len + response_max_len] - attention_mask: Tensor of shape [batch_size, prompt_max_len + response_max_len] - position_ids: Tensor of shape [batch_size, prompt_max_len + response_max_len] - """ - batch_size = len(original_prompts) - total_len = prompt_max_len + response_max_len - - for i, (prompt, response) in enumerate(zip(original_prompts, generated_token_ids_list)): - assert len(prompt) <= prompt_max_len, ( - f"Batch {i}: Prompt length {len(prompt)} exceeds max {prompt_max_len}" - ) - assert len(response) <= response_max_len, ( - f"Batch {i}: Response length {len(response)} exceeds max {response_max_len}" - ) - - # Build batch tensors [batch_size, 2048] - batch_input_ids = torch.full( - (batch_size, total_len), pad_token_id, dtype=torch.long, device="cuda" - ) - batch_attention_mask = torch.zeros((batch_size, total_len), dtype=torch.long, device="cuda") - batch_position_ids = torch.zeros((batch_size, total_len), dtype=torch.long, device="cuda") - - response_lens = [] - - for i in range(batch_size): - prompt_tokens = original_prompts[i] - response_tokens = generated_token_ids_list[i] - - prompt_len = len(prompt_tokens) - response_len = len(response_tokens) - response_lens.append(response_len) - - left_pad_len = prompt_max_len - prompt_len - - # Fill input_ids: [left_pad | prompt | response | right_pad] - prompt_start = left_pad_len - prompt_end = prompt_max_len - response_start = prompt_max_len - response_end = prompt_max_len + response_len - - batch_input_ids[i, prompt_start:prompt_end] = torch.tensor( - prompt_tokens, dtype=torch.long, device="cuda" - ) - batch_input_ids[i, response_start:response_end] = torch.tensor( - response_tokens, dtype=torch.long, device="cuda" - ) - - # Fill attention_mask: 1 for actual tokens, 0 for padding - batch_attention_mask[i, prompt_start:response_end] = 1 - - # Fill position_ids: sequential for actual tokens - actual_seq_len = prompt_len + response_len - batch_position_ids[i, prompt_start:response_end] = torch.arange( - actual_seq_len, dtype=torch.long, device="cuda" - ) - # Right padding keeps the last position value - if response_len < response_max_len: - batch_position_ids[i, response_end:] = actual_seq_len - 1 - - return batch_input_ids, batch_attention_mask, batch_position_ids - - def compare_logprobs(logprobs_list, ref_new_token_logprobs_list): """ logprobs_list: List[torch.Tensor] - LLM logprob values @@ -337,7 +163,7 @@ def test_accuracy_with_allreduce_strategy(model_dir, sampler_type, allreduce_str ray.shutdown() torch.cuda.empty_cache() - input_ids, attention_mask, position_ids = pad_data(test_prompts, llm_responses) + input_ids, attention_mask, position_ids = RefHFModel.pad_data(test_prompts, llm_responses) # Split data across GPUs num_gpus = 4 @@ -347,7 +173,7 @@ def test_accuracy_with_allreduce_strategy(model_dir, sampler_type, allreduce_str dp_hf_models = [] for device_id in range(num_gpus): - hf_model = HFModel(model_dir, device_id) + hf_model = RefHFModel(model_dir, device_id) dp_hf_models.append(hf_model) # Split 
input data and responses into chunks for each GPU @@ -367,7 +193,7 @@ def test_accuracy_with_allreduce_strategy(model_dir, sampler_type, allreduce_str responses_chunks.append(llm_responses[start_idx:end_idx]) # Process each chunk on its corresponding GPU asynchronously - async def process_all_chunks(hf_models: List[HFModel]): + async def process_all_chunks(hf_models: List[RefHFModel]): tasks = [] for i, (input_chunk, attn_chunk, pos_chunk, resp_chunk) in enumerate( zip(input_ids_chunks, attention_mask_chunks, position_ids_chunks, responses_chunks) diff --git a/tests/unittest/_torch/ray_orchestrator/single_gpu/test_llm_update_weights.py b/tests/unittest/_torch/ray_orchestrator/single_gpu/test_llm_update_weights.py index 96e88226122..a7012d21d26 100644 --- a/tests/unittest/_torch/ray_orchestrator/single_gpu/test_llm_update_weights.py +++ b/tests/unittest/_torch/ray_orchestrator/single_gpu/test_llm_update_weights.py @@ -1,62 +1,64 @@ -from typing import Callable, List, Optional +import json +import os +import re +import shutil +from tempfile import TemporaryDirectory +from typing import Callable, List, Optional, Tuple import pytest import torch from torch.multiprocessing.reductions import reduce_tensor -from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer from utils.llm_data import llm_models_root +from utils.torch_ref import RefHFModel from tensorrt_llm import LLM +from tensorrt_llm._torch.utils import get_device_uuid from tensorrt_llm.llmapi import KvCacheConfig, SamplingParams -class HFModel: - def __init__(self, model_name: str): +class RefHFModelWithIPCHandles(RefHFModel): + def __init__(self, model_dir: str, device_id: int = 0, num_hidden_layers: int = 4): + self.device_id = device_id + config = AutoConfig.from_pretrained(model_dir) + config.num_hidden_layers = num_hidden_layers self.model = AutoModelForCausalLM.from_pretrained( - model_name, torch_dtype=torch.bfloat16 - ).to("cuda") - self.tokenizer = AutoTokenizer.from_pretrained(model_name) - self.cuda_device = torch.cuda.current_device() + model_dir, config=config, torch_dtype=torch.bfloat16, attn_implementation="eager" + ).to(f"cuda:{device_id}") self.all_weights = {} - self.device_uuid = [HFModel.get_device_uuid(i) for i in range(torch.cuda.device_count())] + self.device_uuid = [get_device_uuid(i) for i in range(torch.cuda.device_count())] self._replicate_weights() - @staticmethod - def get_device_uuid(cuda_device: int): - from tensorrt_llm._torch.utils import get_device_uuid - - return get_device_uuid(cuda_device) - def _replicate_weights(self): model_weights = [] for n, p in self.model.named_parameters(): model_weights.append((n, p.detach().clone())) - self.all_weights[self.cuda_device] = model_weights + self.all_weights[self.device_id] = model_weights for i in range(torch.cuda.device_count()): - if i != self.cuda_device: + if i != self.device_id: cur_weights = [] - for n, p in self.all_weights[self.cuda_device]: + for n, p in self.all_weights[self.device_id]: cur_weights.append((n, p.to("cuda:" + str(i)))) self.all_weights[i] = cur_weights def get_weight_ipc_handles( self, - cuda_device: Optional[List[int]] = None, + device_ids: Optional[List[int]] = None, weight_filter: Optional[Callable[[str], bool]] = None, ): """ Get IPC handles for model weights with flexible filtering. 
Args: - cuda_device: List of CUDA device indices to get weights from + device_ids: List of device indices to get weights from weight_filter: Optional function that takes weight name and returns True if weight should be included Returns: ret: Dictionary containing weight handles """ ret = {} - device_list = list(range(torch.cuda.device_count())) if cuda_device is None else cuda_device + device_list = list(range(torch.cuda.device_count())) if device_ids is None else device_ids for device in device_list: all_handles = [] @@ -71,50 +73,6 @@ def get_weight_ipc_handles( return ret - def generate_batch_incremental( - self, original_prompts: List[str], generated_token_ids_list: List[List[int]] - ): - """ - Generate tokens incrementally for each prompt in the batch: [prompt, prompt+token0, prompt+token0+token1, ...] - """ - logits_list = [] - - for i in range(len(original_prompts)): - base_token_ids = self.tokenizer.encode(original_prompts[i], return_tensors="pt")[0].to( - "cuda" - ) - cur_logits = [] - for j in range(len(generated_token_ids_list[i])): - if j > 0: - cur_gen_tokens = torch.tensor(generated_token_ids_list[i][:j]).to("cuda") - cur_token_ids = torch.cat([base_token_ids, cur_gen_tokens], dim=-1) - else: - cur_token_ids = base_token_ids - - ret = self.model.generate( - input_ids=cur_token_ids.unsqueeze(0).cuda(), - max_new_tokens=1, - do_sample=False, - return_dict_in_generate=True, - output_scores=True, - ) - - cur_logits.append(ret["scores"][0]) - cur_logits = torch.stack(cur_logits, dim=0) - logits_list.append(cur_logits.squeeze(1)) - - return logits_list - - -def extract_tokens_from_outputs(outputs): - """Extract individual tokens from LLM outputs using token IDs directly""" - tokens_list = [] - for output in outputs: - # Get token IDs directly from the output - token_ids = output.outputs[0].token_ids - tokens_list.append(token_ids) - return tokens_list - def compare_logits( logits_list: List[torch.Tensor], @@ -123,7 +81,6 @@ def compare_logits( threshold: float = 0.9, ): assert len(logits_list) == len(ref_logits_list) - for i in range(len(logits_list)): assert logits_list[i].shape == ref_logits_list[i].shape lhs_idx = torch.topk(logits_list[i], topk, dim=-1).indices @@ -142,119 +99,168 @@ def compare_logits( ) -def run_generate(llm, hf_model, prompts, sampling_params): - outputs = llm.generate(prompts, sampling_params) +def run_generate( + llm: LLM, + hf_model: RefHFModel, + prompts: List[List[int]], + sampling_params: SamplingParams, +) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + llm_responses = [] llm_logits = [] + outputs = llm.generate(prompts, sampling_params) for output in outputs: llm_logits.append(output.outputs[0].generation_logits) - - generated_token_ids_list = extract_tokens_from_outputs(outputs) - ref_logits = hf_model.generate_batch_incremental(prompts, generated_token_ids_list) + llm_responses.append(output.outputs[0].token_ids) + input_ids, attention_mask, position_ids = RefHFModel.pad_data(prompts, llm_responses) + ref_logits = hf_model.generate_batch_with_padding( + input_ids, attention_mask, position_ids, llm_responses, return_logits=True + ) return llm_logits, ref_logits +def process_and_copy_folder(src_folder, dst_folder, num_hidden_layers: int = 4): + if os.path.exists(dst_folder): + shutil.rmtree(dst_folder) + os.makedirs(dst_folder) + + for root, dirs, files in os.walk(src_folder): + rel_path = os.path.relpath(root, src_folder) + dest_dir = os.path.join(dst_folder, rel_path) + + if not os.path.exists(dest_dir): + os.makedirs(dest_dir) + + for file in 
files: + src_path = os.path.join(root, file) + dest_path = os.path.join(dest_dir, file) + if "safetensor" in file: + continue + + if file == "config.json": + with open(src_path, "r", encoding="utf-8") as f: + config = json.load(f) + config["num_hidden_layers"] = num_hidden_layers + with open(dest_path, "w", encoding="utf-8") as f: + json.dump(config, f, indent=2, ensure_ascii=False) + else: + shutil.copy2(src_path, dest_path) + + @pytest.mark.parametrize( "model_dir", - ["Qwen2.5-0.5B-Instruct", "Qwen3/Qwen3-8B", "llama-models-v2/TinyLlama-1.1B-Chat-v1.0"], + [ + "llama-models-v2/TinyLlama-1.1B-Chat-v1.0", + "Qwen2.5-0.5B-Instruct", + "Qwen3/Qwen3-8B", + "Qwen3/Qwen3-30B-A3B", + "Qwen3/Qwen3-8B-FP8", + "Qwen3/Qwen3-30B-A3B-FP8", + ], ) def test_llm_update_weights(model_dir): model_dir = str(llm_models_root() / model_dir) - kv_cache_config = KvCacheConfig(enable_block_reuse=True, free_gpu_memory_fraction=0.1) - - hf_model = HFModel(model_dir) - - llm = LLM( - model=model_dir, - ray_worker_extension_cls="tensorrt_llm.llmapi.rlhf_utils.WorkerExtension", - tensor_parallel_size=1, - load_format="dummy", - pipeline_parallel_size=1, - kv_cache_config=kv_cache_config, - ) - - # Generate texts from the prompts. - prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ] + with TemporaryDirectory() as tmp_model_dir: + num_hidden_layers = 1 + process_and_copy_folder(model_dir, tmp_model_dir, num_hidden_layers=num_hidden_layers) + hf_model = RefHFModelWithIPCHandles(model_dir, num_hidden_layers=num_hidden_layers) + tokenizer = AutoTokenizer.from_pretrained(model_dir) + kv_cache_config = KvCacheConfig(enable_block_reuse=True, free_gpu_memory_fraction=0.1) + llm = LLM( + model=tmp_model_dir, + ray_worker_extension_cls="tensorrt_llm.llmapi.rlhf_utils.WorkerExtension", + tensor_parallel_size=1, + load_format="dummy", + pipeline_parallel_size=1, + kv_cache_config=kv_cache_config, + ) - sampling_params = SamplingParams(temperature=0, return_generation_logits=True) + # Generate texts from the prompts. 
+ prompts_texts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + prompts = [tokenizer.encode(prompt) for prompt in prompts_texts] + del tokenizer + sampling_params = SamplingParams( + temperature=0, return_generation_logits=True, max_tokens=1024 + ) - ipc_handles = hf_model.get_weight_ipc_handles([0]) + ipc_handles = hf_model.get_weight_ipc_handles([0]) - llm._collective_rpc("update_weights", (ipc_handles,)) - # Finalize the update weights - llm._collective_rpc("update_weights", (None,)) + llm._collective_rpc("update_weights", (ipc_handles,)) + # Finalize the update weights + llm._collective_rpc("update_weights", (None,)) - llm_logits, ref_logits = run_generate(llm, hf_model, prompts, sampling_params) - compare_logits(llm_logits, ref_logits) + llm_logits, ref_logits = run_generate(llm, hf_model, prompts, sampling_params) + compare_logits(llm_logits, ref_logits) @pytest.mark.parametrize( "model_dir", - ["Qwen2.5-0.5B-Instruct", "Qwen3/Qwen3-8B", "llama-models-v2/TinyLlama-1.1B-Chat-v1.0"], + [ + "llama-models-v2/TinyLlama-1.1B-Chat-v1.0", + "Qwen2.5-0.5B-Instruct", + "Qwen3/Qwen3-8B", + "Qwen3/Qwen3-30B-A3B", + "Qwen3/Qwen3-8B-FP8", + "Qwen3/Qwen3-30B-A3B-FP8", + ], ) def test_llm_partial_update_weights(model_dir): model_dir = str(llm_models_root() / model_dir) - kv_cache_config = KvCacheConfig(enable_block_reuse=True, free_gpu_memory_fraction=0.1) - - hf_model = HFModel(model_dir) - - llm = LLM( - model=model_dir, - ray_worker_extension_cls="tensorrt_llm.llmapi.rlhf_utils.WorkerExtension", - tensor_parallel_size=1, - load_format="dummy", - pipeline_parallel_size=1, - kv_cache_config=kv_cache_config, - ) + with TemporaryDirectory() as tmp_model_dir: + num_hidden_layers = 1 + process_and_copy_folder(model_dir, tmp_model_dir, num_hidden_layers=num_hidden_layers) + hf_model = RefHFModelWithIPCHandles(model_dir, num_hidden_layers=num_hidden_layers) + tokenizer = AutoTokenizer.from_pretrained(model_dir) + kv_cache_config = KvCacheConfig(enable_block_reuse=True, free_gpu_memory_fraction=0.1) + + llm = LLM( + model=tmp_model_dir, + ray_worker_extension_cls="tensorrt_llm.llmapi.rlhf_utils.WorkerExtension", + tensor_parallel_size=1, + load_format="dummy", + pipeline_parallel_size=1, + kv_cache_config=kv_cache_config, + ) - # Generate texts from the prompts. - prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ] - - sampling_params = SamplingParams(temperature=0, return_generation_logits=True) - - ipc_handles = hf_model.get_weight_ipc_handles([0]) - - def common_filter(filter_name: str) -> Callable[[str], bool]: - def filter_fn(name: str) -> bool: - return filter_name in name - - return filter_fn - - filter_list = [ - "q_proj.weight", - "k_proj.weight", - "v_proj.weight", - "o_proj.weight", - "gate_proj.weight", - "up_proj.weight", - "down_proj.weight", - "norm.weight", - "embed_tokens.weight", - "lm_head.weight", - ] - if "Qwen2.5" in model_dir or "Qwen2" in model_dir: - filter_list.extend( - [ - "q_proj.bias", - "k_proj.bias", - "v_proj.bias", - ] + # Generate texts from the prompts. 
+ prompts_texts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + prompts = [tokenizer.encode(prompt) for prompt in prompts_texts] + del tokenizer + + sampling_params = SamplingParams( + temperature=0, return_generation_logits=True, max_tokens=1024 ) - for filter_name in filter_list: - weight_filter = common_filter(filter_name=filter_name) - ipc_handles = hf_model.get_weight_ipc_handles([0], weight_filter=weight_filter) - llm._collective_rpc("update_weights", (ipc_handles,), non_block=True) - # Finalize the update weights - llm._collective_rpc("update_weights", (None,)) - - llm_logits, ref_logits = run_generate(llm, hf_model, prompts, sampling_params) - compare_logits(llm_logits, ref_logits) + + def common_filter(filter_name: str) -> Callable[[str], bool]: + def filter_fn(name: str) -> bool: + return name.endswith(filter_name) + + return filter_fn + + # Generate filter_list from model weight keys by removing layer prefix + # e.g., "model.layers.41.input_layernorm.weight" -> "input_layernorm.weight" + layer_prefix_pattern = re.compile(r"^model\.layers\.\d+\.") + filter_set = set() + for name, _ in hf_model.all_weights[hf_model.device_id]: + suffix = layer_prefix_pattern.sub("", name) + filter_set.add(suffix) + filter_list = list(filter_set) + + for filter_name in filter_list: + weight_filter = common_filter(filter_name=filter_name) + ipc_handles = hf_model.get_weight_ipc_handles([0], weight_filter=weight_filter) + llm._collective_rpc("update_weights", (ipc_handles,)) + # Finalize the update weights + llm._collective_rpc("update_weights", (None,)) + + llm_logits, ref_logits = run_generate(llm, hf_model, prompts, sampling_params) + compare_logits(llm_logits, ref_logits) diff --git a/tests/unittest/utils/torch_ref.py b/tests/unittest/utils/torch_ref.py index d8a6b258c57..7b93026ee04 100644 --- a/tests/unittest/utils/torch_ref.py +++ b/tests/unittest/utils/torch_ref.py @@ -13,12 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -from typing import Optional +from typing import Any, Dict, List, Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange, repeat +from transformers import AutoModelForCausalLM def geglu(x): @@ -1239,3 +1240,188 @@ def forward_impl( x = self.linear_out(x) return x, conv_state, lru_state + + +class RefHFModel: + + def __init__(self, + model_dir: str, + device_id: int = 0, + additional_model_kargs: Optional[Dict[str, Any]] = None): + self.device_id = device_id + self.model = AutoModelForCausalLM.from_pretrained( + model_dir, **(additional_model_kargs or {})).to(f"cuda:{device_id}") + + def generate_batch_with_padding( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + position_ids: torch.Tensor, + responses: List[List[int]], + prompt_max_len: int = 1024, + micro_batch_size: int = 16, + return_logits: bool = False, + ): + """ + Synchronous inference on a batch with micro-batching. + Directly extracts response logprobs to save memory. 
+
+        Args:
+            input_ids: [batch_size, seq_len]
+            attention_mask: [batch_size, seq_len]
+            position_ids: [batch_size, seq_len]
+            responses: List of response token IDs for each sample
+            prompt_max_len: Maximum prompt length (default 1024)
+            micro_batch_size: Size of each micro batch to avoid OOM
+            return_logits: If True, return logits; otherwise, return logprobs
+        Returns:
+            List of logprobs tensors ([response_len]) or logits tensors ([response_len, vocab_size]), one per sample
+        """
+        # Move tensors to the correct device
+        input_ids = input_ids.to(f"cuda:{self.device_id}")
+        attention_mask = attention_mask.to(f"cuda:{self.device_id}")
+        position_ids = position_ids.to(f"cuda:{self.device_id}")
+
+        batch_size = input_ids.shape[0]
+        num_micro_batches = (batch_size + micro_batch_size -
+                             1) // micro_batch_size
+
+        ref_results = []
+
+        with torch.no_grad():
+            for micro_idx in range(num_micro_batches):
+                start_idx = micro_idx * micro_batch_size
+                end_idx = min((micro_idx + 1) * micro_batch_size, batch_size)
+
+                # Extract micro batch
+                micro_input_ids = input_ids[start_idx:end_idx]
+                micro_attention_mask = attention_mask[start_idx:end_idx]
+                micro_position_ids = position_ids[start_idx:end_idx]
+
+                # Forward pass
+                outputs = self.model(
+                    input_ids=micro_input_ids,
+                    attention_mask=micro_attention_mask,
+                    position_ids=micro_position_ids,
+                )
+                # Extract response logits/logprobs for each sample in this micro batch
+                for i in range(outputs.logits.shape[0]):
+                    sample_idx = start_idx + i
+                    response = responses[sample_idx]
+                    response_len = len(response)
+
+                    # Extract logits for predicting response tokens
+                    # For predicting response[j], we need logits at position prompt_max_len-1+j
+                    response_logits = outputs.logits[i, prompt_max_len -
+                                                     1:prompt_max_len - 1 +
+                                                     response_len, :]
+                    if return_logits:
+                        ref_results.append(response_logits)
+                    else:
+                        # Convert to logprobs
+                        response_logprobs = torch.log_softmax(response_logits,
+                                                              dim=-1)
+
+                        # Extract logprobs for the actual generated tokens
+                        response_tensor = torch.tensor(
+                            response,
+                            dtype=torch.long,
+                            device=response_logprobs.device)
+                        ref_logprob_for_tokens = torch.gather(
+                            response_logprobs,
+                            dim=-1,
+                            index=response_tensor.unsqueeze(-1)).squeeze(-1)
+
+                        ref_results.append(ref_logprob_for_tokens)
+
+                # Free memory immediately after processing each micro batch
+                del outputs
+                torch.cuda.empty_cache()
+
+        return ref_results
+
+    @staticmethod
+    def pad_data(
+        original_prompts: List[List[int]],
+        generated_token_ids_list: List[List[int]],
+        prompt_max_len: int = 1024,
+        response_max_len: int = 1024,
+        pad_token_id: int = 0,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Pad the data to the maximum length.
+ + Structure: + [left_pad | actual_prompt | actual_response | right_pad] + |<-- prompt_max_len=1024 -->|<-- response_max_len=1024 -->| + + Args: + original_prompts: List of prompt token IDs, len = batch_size + generated_token_ids_list: List of response token IDs, len = batch_size + prompt_max_len: Maximum length for prompt section (default 1024) + response_max_len: Maximum length for response section (default 1024) + pad_token_id: Token ID for padding (default 0) + Returns: + input_ids: Tensor of shape [batch_size, prompt_max_len + response_max_len] + attention_mask: Tensor of shape [batch_size, prompt_max_len + response_max_len] + position_ids: Tensor of shape [batch_size, prompt_max_len + response_max_len] + """ + batch_size = len(original_prompts) + total_len = prompt_max_len + response_max_len + + for i, (prompt, response) in enumerate( + zip(original_prompts, generated_token_ids_list)): + assert len(prompt) <= prompt_max_len, ( + f"Batch {i}: Prompt length {len(prompt)} exceeds max {prompt_max_len}" + ) + assert len(response) <= response_max_len, ( + f"Batch {i}: Response length {len(response)} exceeds max {response_max_len}" + ) + + # Build batch tensors [batch_size, total_len] + batch_input_ids = torch.full((batch_size, total_len), + pad_token_id, + dtype=torch.long, + device="cuda") + batch_attention_mask = torch.zeros((batch_size, total_len), + dtype=torch.long, + device="cuda") + batch_position_ids = torch.zeros((batch_size, total_len), + dtype=torch.long, + device="cuda") + + response_lens = [] + + for i in range(batch_size): + prompt_tokens = original_prompts[i] + response_tokens = generated_token_ids_list[i] + + prompt_len = len(prompt_tokens) + response_len = len(response_tokens) + response_lens.append(response_len) + + left_pad_len = prompt_max_len - prompt_len + + # Fill input_ids: [left_pad | prompt | response | right_pad] + prompt_start = left_pad_len + prompt_end = prompt_max_len + response_start = prompt_max_len + response_end = prompt_max_len + response_len + + batch_input_ids[i, prompt_start:prompt_end] = torch.tensor( + prompt_tokens, dtype=torch.long, device="cuda") + batch_input_ids[i, response_start:response_end] = torch.tensor( + response_tokens, dtype=torch.long, device="cuda") + + # Fill attention_mask: 1 for actual tokens, 0 for padding + batch_attention_mask[i, prompt_start:response_end] = 1 + + # Fill position_ids: sequential for actual tokens + actual_seq_len = prompt_len + response_len + batch_position_ids[i, prompt_start:response_end] = torch.arange( + actual_seq_len, dtype=torch.long, device="cuda") + # Right padding keeps the last position value + if response_len < response_max_len: + batch_position_ids[i, response_end:] = actual_seq_len - 1 + + return batch_input_ids, batch_attention_mask, batch_position_ids
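Note on the layout built by RefHFModel.pad_data: the [left_pad | prompt | response | right_pad] structure can be hard to picture from the slicing code alone. Below is a minimal, CPU-only sketch of the same indexing using made-up lengths (prompt_max_len=4, response_max_len=4 instead of the real 1024/1024 defaults); it illustrates the layout and is not a call into the test utility itself.

import torch

# Illustrative values only: tiny section lengths and arbitrary token IDs.
prompt, response = [11, 12], [21, 22, 23]
prompt_max_len, response_max_len, pad_token_id = 4, 4, 0
total_len = prompt_max_len + response_max_len

input_ids = torch.full((total_len,), pad_token_id, dtype=torch.long)
attention_mask = torch.zeros(total_len, dtype=torch.long)
position_ids = torch.zeros(total_len, dtype=torch.long)

left_pad_len = prompt_max_len - len(prompt)                       # 2
response_end = prompt_max_len + len(response)                     # 7
input_ids[left_pad_len:prompt_max_len] = torch.tensor(prompt)     # prompt right-aligned to the boundary
input_ids[prompt_max_len:response_end] = torch.tensor(response)   # response starts right after the boundary
attention_mask[left_pad_len:response_end] = 1                     # 1 on real tokens, 0 on padding
actual_seq_len = len(prompt) + len(response)                      # 5
position_ids[left_pad_len:response_end] = torch.arange(actual_seq_len)
position_ids[response_end:] = actual_seq_len - 1                  # right padding keeps the last position value

# input_ids      -> [ 0,  0, 11, 12, 21, 22, 23,  0]
# attention_mask -> [ 0,  0,  1,  1,  1,  1,  1,  0]
# position_ids   -> [ 0,  0,  0,  1,  2,  3,  4,  4]

With the real defaults the same pattern applies, just with 1024-token prompt and response sections, which is why generate_batch_with_padding reads response logits starting at position prompt_max_len - 1.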