@@ -68,7 +68,10 @@ def load_weight_shard(
         tensor_parallel_mode: Optional[TensorParallelMode] = None,
         device: torch.device = torch.device('cpu'),
         return_slice_indices: bool = False,
+        weight_name: Optional[str] = None,
 ) -> torch.Tensor:
+    if weight_name is not None:
+        print(f"[load_weight_shard] weight_name: {weight_name}")
     # Skip device transfers on integrated GPUs to conserve shared memory
     if weight.device.type != device.type and is_device_integrated():
         # For integrated GPU systems (e.g., DGX Spark), CPU and GPU share limited physical memory.
@@ -112,6 +115,10 @@ def maybe_convert_to_torch_tensor(
     if width == 1:
         return maybe_convert_to_torch_tensor(weight)
 
+    if weight_name is not None and tensor_parallel_rank == 1:
+        print(f"[load_weight_shard] THIS IS WHERE YOU SWAP RANK 1.")
+    if weight_name is not None and tensor_parallel_rank == 2:
+        print(f"[load_weight_shard] THIS IS WHERE YOU SWAP RANK 2.")
     slice_width = math.ceil(width / tensor_parallel_size)
     slice_start = tensor_parallel_rank * slice_width
     slice_end = min((tensor_parallel_rank + 1) * slice_width, width)
@@ -140,7 +147,10 @@ def load_weights_vanilla_helper(module: Linear,
                                 weights: List[Dict],
                                 weight_transform=lambda x: x,
                                 bias_transform=lambda x: x,
-                                allow_partial_loading: bool = False):
+                                allow_partial_loading: bool = False,
+                                weight_name: Optional[str] = None):
+    if weight_name is not None:
+        print(f"[load_weights_vanilla_helper] weight_name: {weight_name}")
     assert len(weights) == 1
     if not allow_partial_loading:
         assert "weight" in weights[0]
@@ -150,7 +160,7 @@ def load_weights_vanilla_helper(module: Linear,
 
     weight = load_weight_shard(weights[0]['weight'], module.tp_size,
                                module.tp_rank, module.tp_mode,
-                               device) if "weight" in weights[0] else None
+                               device, weight_name=weight_name) if "weight" in weights[0] else None
 
     if weight is not None:
         if module.has_weight_only_quant:
@@ -167,7 +177,7 @@ def load_weights_vanilla_helper(module: Linear,
     if module.bias is not None:
         bias = load_weight_shard(weights[0]['bias'], module.tp_size,
                                  module.tp_rank, module.tp_mode,
-                                 device) if "bias" in weights[0] else None
+                                 device, weight_name=weight_name) if "bias" in weights[0] else None
         if bias is not None:
             copy_weight(module.bias, bias_transform(bias))
 
@@ -311,7 +321,8 @@ def load_weights(self,
                      module: Linear,
                      weights: List[Dict],
                      weight_mode: WeightMode,
-                     allow_partial_loading: bool = False):
+                     allow_partial_loading: bool = False,
+                     weight_name: Optional[str] = None):
         """
         Load weights from the checkpoint.
         """
@@ -396,10 +407,14 @@ def apply(self, module: Linear, input: torch.Tensor,
     def load_weights_vanilla(self,
                              module: Linear,
                              weights: List[Dict],
-                             allow_partial_loading: bool = False) -> None:
+                             allow_partial_loading: bool = False,
+                             weight_name: Optional[str] = None) -> None:
+        if weight_name is not None:
+            print(f"[UnquantizedLinearMethod::load_weights_vanilla] weight_name: {weight_name}")
         load_weights_vanilla_helper(module,
                                     weights,
-                                    allow_partial_loading=allow_partial_loading)
+                                    allow_partial_loading=allow_partial_loading,
+                                    weight_name=weight_name)
 
     def load_weights_fused_qkv_linear(
         self,
@@ -2058,6 +2073,8 @@ def __init__(
         disable_deep_gemm: bool = False,
         fused_weight_shard_indices_mapping: Optional[dict] = None,
         nvfp4_allowed_backends: Optional[List[str]] = None,
+        weight_name: Optional[str] = None,
+        mapping_with_cp: Optional[Mapping] = None,
     ):
         """
         Args:
@@ -2098,6 +2115,12 @@ def __init__(
             'cutlass', 'cublaslt', 'cuda_core'
         ]
 
+        if mapping_with_cp is not None and weight_name == "o_proj":
+            print("[Linear::__init__] Found o_proj with CP mapping. Setting weight_name to o_proj_with_cp.")
+            self.weight_name = "o_proj_with_cp"
+        else:
+            self.weight_name = None
+
         local_in_features = in_features
         local_out_features = out_features
 
@@ -2284,7 +2307,8 @@ def load_weights(self,
             self,
             weights,
             weight_mode,
-            allow_partial_loading=allow_partial_loading)
+            allow_partial_loading=allow_partial_loading,
+            weight_name=self.weight_name)
 
     def post_load_weights(self):
         self.quant_method.post_load_weights(self)
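Taken together, these hunks thread an optional weight_name debug tag from Linear.__init__ through the quant method's load_weights and load_weights_vanilla paths down into load_weight_shard, so a specific weight (here o_proj under a context-parallel mapping) can be traced per tensor-parallel rank while its shard is cut. The standalone sketch below illustrates the same pattern under simplifying assumptions; the function name, the column-wise slicing, and the tensor shapes are illustrative stand-ins, not the actual TensorRT-LLM implementation.

from typing import Optional
import math

import torch


def load_weight_shard_sketch(weight: torch.Tensor,
                             tp_size: int = 1,
                             tp_rank: int = 0,
                             weight_name: Optional[str] = None) -> torch.Tensor:
    # Optional debug tag, mirroring the weight_name parameter added in the diff:
    # announce which weight this rank is about to shard.
    if weight_name is not None:
        print(f"[load_weight_shard_sketch] rank {tp_rank}: sharding {weight_name}")
    # Column-wise sharding of the last dimension (illustrative only).
    width = weight.shape[-1]
    slice_width = math.ceil(width / tp_size)
    slice_start = tp_rank * slice_width
    slice_end = min((tp_rank + 1) * slice_width, width)
    return weight[..., slice_start:slice_end]


# Usage: tag the o_proj weight so each rank reports its shard while loading.
full_weight = torch.randn(16, 16)
for rank in range(4):
    shard = load_weight_shard_sketch(full_weight, tp_size=4, tp_rank=rank,
                                     weight_name="o_proj_with_cp")
    print(f"rank {rank}: shard shape {tuple(shard.shape)}")

Running the loop prints the tag once per rank, which is the kind of per-rank trace the debug prints added in the diff are meant to produce.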