@@ -269,23 +269,28 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") ->
         QUANTIZATION_NVFP4_AWQ,
         QUANTIZATION_W4A8_NVFP4_FP8,
     ]:
+        # If a scale is already registered, the weights are already compressed.
+        # Convert it to a modelopt scale if necessary and return it.
         if hasattr(weight_quantizer, "_scale"):
             return NVFP4QTensor.get_modelopt_weights_scaling_factor(
                 weight_quantizer._scale, weight.metadata["shape"]
             )
-
-        return NVFP4QTensor.get_weights_scaling_factor(
-            weight,
-            weight_quantizer.block_sizes[-1],
-            NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer).to(
-                weight.device
-            ),
-        )[0]
+        else:
+            return NVFP4QTensor.get_weights_scaling_factor(
+                weight,
+                weight_quantizer.block_sizes[-1],
+                NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer).to(
+                    weight.device
+                ),
+            )[0]

     if quantization_format in [QUANTIZATION_W4A8_MXFP4_FP8, QUANTIZATION_MXFP4]:
-        return MXFP4QTensor.quantize(weight, block_size=weight_quantizer.block_sizes[-1])[
-            1
-        ].reshape(*weight.shape[:-1], -1)
+        if hasattr(weight_quantizer, "_scale"):
+            return weight_quantizer._scale.reshape(*weight.shape[:-1], -1)
+        else:
+            return MXFP4QTensor.quantize(weight, block_size=weight_quantizer.block_sizes[-1])[
+                1
+            ].reshape(*weight.shape[:-1], -1)
     return get_scaling_factor(weight_quantizer)
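For context on the MXFP4 branch: a minimal sketch of why the per-block scale is reshaped with `*weight.shape[:-1], -1`. This is illustrative only (hypothetical shapes and values, not ModelOpt code): block quantization along the last dimension yields one scale per block, so a `[out, in]` weight with block size 32 produces a `[out, in // 32]` scale tensor.

```python
import torch

# Illustrative shapes only: block quantization along the last dim yields one
# scale per block, so reshaping a flat scale buffer with *weight.shape[:-1], -1
# recovers a [out_features, in_features // block_size] layout.
out_features, in_features, block_size = 4, 64, 32
weight = torch.randn(out_features, in_features)
num_blocks = weight.numel() // block_size
flat_scales = torch.ones(num_blocks)  # stand-in for a registered _scale buffer
per_block = flat_scales.reshape(*weight.shape[:-1], -1)
assert per_block.shape == (out_features, in_features // block_size)  # (4, 2)
```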
@@ -301,7 +306,10 @@ def get_weight_scaling_factor_2(module: nn.Module, weight_name: str = "weight")
         QUANTIZATION_NVFP4_AWQ,
         QUANTIZATION_W4A8_NVFP4_FP8,
     ]:
-        return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer)
+        if hasattr(weight_quantizer, "_double_scale"):
+            return weight_quantizer._double_scale
+        else:
+            return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer)

     # SequentialQuantizer is required
     if not isinstance(weight_quantizer, SequentialQuantizer) or not weight_quantizer[-1].is_enabled:
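Both hunks apply the same prefer-the-cached-attribute pattern: reuse a scale that compression already registered on the quantizer, and only recompute from calibration state otherwise. A generic sketch of that pattern (the helper name is hypothetical; the attribute names come from the diff):

```python
# Hedged sketch of the dispatch pattern above; cached_or_computed is a
# hypothetical helper, not part of this PR.
def cached_or_computed(quantizer, attr: str, compute_fn):
    if hasattr(quantizer, attr):
        # Compressed checkpoints already carry this scale; reuse it as-is.
        return getattr(quantizer, attr)
    # Uncompressed path: derive the scale from the quantizer's state.
    return compute_fn()

# e.g. cached_or_computed(weight_quantizer, "_double_scale",
#     lambda: NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer))
```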
@@ -818,7 +826,12 @@ def from_quantized_weight(
         raise NotImplementedError(f"quantization format {quantization} not supported")


-def postprocess_state_dict(state_dict: dict, maxbound: float, quantization: str | None) -> dict:
+def postprocess_state_dict(
+    state_dict: dict,
+    maxbound: float,
+    quantization: str | None,
+    is_modelopt_trained_lora: bool = False,
+) -> dict:
     """Filters out keys related to weight quantizers and updates KV cache related keys.

     Args:
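A hypothetical call site for the new keyword (the model, maxbound, and format values below are illustrative and not taken from this diff):

```python
# Hypothetical usage: LoRA handling is now opt-in, so plain quantized
# checkpoints keep the previous default behavior.
state_dict = model.state_dict()  # assumption: a quantized modelopt model
post_state_dict = postprocess_state_dict(
    state_dict,
    maxbound=448.0,      # illustrative, e.g. the FP8 E4M3 maxbound
    quantization="fp8",  # illustrative format string
    is_modelopt_trained_lora=True,
)
```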
@@ -835,11 +848,18 @@ def postprocess_state_dict(state_dict: dict, maxbound: float, quantization: str
         "k_bmm_quantizer._bias_value": "k_proj.k_bias",
         "v_bmm_quantizer._bias_value": "v_proj.v_bias",
         "input_quantizer._pre_quant_scale": "pre_quant_scale",
-        "base_layer.weight": "weight",
-        "base_layer.input_scale": "input_scale",
-        "base_layer.weight_scale": "weight_scale",
     }

+    # For modelopt-trained LoRA models, remove the base_layer prefix from the keys for deployment.
+    if is_modelopt_trained_lora:
+        replacements.update(
+            {
+                "base_layer.weight": "weight",
+                "base_layer.input_scale": "input_scale",
+                "base_layer.weight_scale": "weight_scale",
+            }
+        )
+
     post_state_dict = {}

     for key, value in state_dict.items():
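A toy illustration of what the now-conditional `replacements` entries do. The key and the substring-replace loop below are hypothetical (the hunk does not show the actual rewrite loop):

```python
# Toy example with a hypothetical checkpoint key: when
# is_modelopt_trained_lora=True, the base_layer prefix that PEFT-style
# wrapping adds is stripped so deployed keys match the base model's layout.
replacements = {
    "base_layer.weight": "weight",
    "base_layer.input_scale": "input_scale",
    "base_layer.weight_scale": "weight_scale",
}
key = "model.layers.0.mlp.up_proj.base_layer.weight"
for old, new in replacements.items():
    if old in key:
        key = key.replace(old, new)
assert key == "model.layers.0.mlp.up_proj.weight"
```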
@@ -902,10 +922,10 @@ def postprocess_state_dict(state_dict: dict, maxbound: float, quantization: str
             keys_to_delete.append(key)

     # remove LoRA adapters from state dict
-    for key, value in post_state_dict.items():
-        if "lora" in key and key not in keys_to_delete:
-            keys_to_delete.append(key)
-
+    if is_modelopt_trained_lora:
+        for key, value in post_state_dict.items():
+            if "lora" in key and key not in keys_to_delete:
+                keys_to_delete.append(key)
     # Check for tied weights and remove duplicates
     seen_tensors = {}
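Finally, a sketch of the now-gated adapter cleanup, assuming typical PEFT-style `lora_A`/`lora_B` key names (the keys below are hypothetical):

```python
# Hypothetical keys: with is_modelopt_trained_lora=True the adapter weights
# are dropped from the export; without the flag they are now left in place.
post_state_dict = {
    "model.layers.0.self_attn.q_proj.weight": "...",
    "model.layers.0.self_attn.q_proj.lora_A.weight": "...",
    "model.layers.0.self_attn.q_proj.lora_B.weight": "...",
}
keys_to_delete = [k for k in post_state_dict if "lora" in k]
for k in keys_to_delete:
    post_state_dict.pop(k)
assert list(post_state_dict) == ["model.layers.0.self_attn.q_proj.weight"]
```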