@@ -270,23 +270,28 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") ->
         QUANTIZATION_NVFP4_AWQ,
         QUANTIZATION_W4A8_NVFP4_FP8,
     ]:
+        # If a scale is already registered, the weights are already compressed;
+        # convert to a modelopt scale if necessary and return it.
         if hasattr(weight_quantizer, "_scale"):
             return NVFP4QTensor.get_modelopt_weights_scaling_factor(
                 weight_quantizer._scale, weight.metadata["shape"]
             )
-
-        return NVFP4QTensor.get_weights_scaling_factor(
-            weight,
-            weight_quantizer.block_sizes[-1],
-            NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer).to(
-                weight.device
-            ),
-        )[0]
+        else:
+            return NVFP4QTensor.get_weights_scaling_factor(
+                weight,
+                weight_quantizer.block_sizes[-1],
+                NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer).to(
+                    weight.device
+                ),
+            )[0]

     if quantization_format in [QUANTIZATION_W4A8_MXFP4_FP8, QUANTIZATION_MXFP4]:
-        return MXFP4QTensor.quantize(weight, block_size=weight_quantizer.block_sizes[-1])[
-            1
-        ].reshape(*weight.shape[:-1], -1)
+        if hasattr(weight_quantizer, "_scale"):
+            return weight_quantizer._scale.reshape(*weight.shape[:-1], -1)
+        else:
+            return MXFP4QTensor.quantize(weight, block_size=weight_quantizer.block_sizes[-1])[
+                1
+            ].reshape(*weight.shape[:-1], -1)
     return get_scaling_factor(weight_quantizer)


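The hunk above adds a fast path for checkpoints whose weights are already compressed: when a `_scale` buffer is present on the quantizer, it is reshaped and reused instead of being recomputed from the raw weights. A minimal sketch of that dispatch, using stand-in names rather than the real modelopt `TensorQuantizer`/`NVFP4QTensor` API:

```python
import torch


def weight_scale(weight: torch.Tensor, quantizer, block_size: int = 16) -> torch.Tensor:
    """Toy per-block amax scale; reuses a pre-registered scale when present."""
    if hasattr(quantizer, "_scale"):
        # Compression already attached a scale buffer; just reshape and reuse it.
        return quantizer._scale.reshape(*weight.shape[:-1], -1)
    # Otherwise derive a per-block absolute-max scale from the raw weights.
    blocks = weight.reshape(*weight.shape[:-1], -1, block_size)
    return blocks.abs().amax(dim=-1)
```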
@@ -302,7 +307,10 @@ def get_weight_scaling_factor_2(module: nn.Module, weight_name: str = "weight")
         QUANTIZATION_NVFP4_AWQ,
         QUANTIZATION_W4A8_NVFP4_FP8,
     ]:
-        return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer)
+        if hasattr(weight_quantizer, "_double_scale"):
+            return weight_quantizer._double_scale
+        else:
+            return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer)

     # SequentialQuantizer is required
     if not isinstance(weight_quantizer, SequentialQuantizer) or not weight_quantizer[-1].is_enabled:
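`get_weight_scaling_factor_2` gets the same treatment: a pre-registered `_double_scale` (the second-level, per-tensor scale) short-circuits the recomputation. A rough sketch of why two levels exist, assuming the NVFP4-style convention where per-block scales are stored in FP8 and themselves scaled by a global factor; the constants here are illustrative, not pulled from this diff:

```python
import torch


def two_level_scales(weight: torch.Tensor, block_size: int = 16):
    """Toy double scaling: a global FP32 scale plus per-block scales."""
    # Second-level scale maps per-block amax values into FP8 range (448.0),
    # given FP4's maximum representable magnitude (6.0).
    scale_2 = weight.abs().amax() / (448.0 * 6.0)
    blocks = weight.reshape(*weight.shape[:-1], -1, block_size)
    scale = blocks.abs().amax(dim=-1) / 6.0 / scale_2  # per-block, FP8-storable
    return scale, scale_2
```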
@@ -824,7 +832,12 @@ def from_quantized_weight(
     raise NotImplementedError(f"quantization format {quantization} not supported")


-def postprocess_state_dict(state_dict: dict, maxbound: float, quantization: str | None) -> dict:
+def postprocess_state_dict(
+    state_dict: dict,
+    maxbound: float,
+    quantization: str | None,
+    is_modelopt_trained_lora: bool = False,
+) -> dict:
     """Filters out keys related to weight quantizers and updates KV cache related keys.

     Args:
@@ -841,11 +854,18 @@ def postprocess_state_dict(state_dict: dict, maxbound: float, quantization: str
         "k_bmm_quantizer._bias_value": "k_proj.k_bias",
         "v_bmm_quantizer._bias_value": "v_proj.v_bias",
         "input_quantizer._pre_quant_scale": "pre_quant_scale",
-        "base_layer.weight": "weight",
-        "base_layer.input_scale": "input_scale",
-        "base_layer.weight_scale": "weight_scale",
     }

+    # For modelopt-trained LoRA models, remove the base_layer prefix from keys for deployment.
+    if is_modelopt_trained_lora:
+        replacements.update(
+            {
+                "base_layer.weight": "weight",
+                "base_layer.input_scale": "input_scale",
+                "base_layer.weight_scale": "weight_scale",
+            }
+        )
+
     post_state_dict = {}

     for key, value in state_dict.items():
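A hypothetical before/after of the `base_layer` renaming that the new `is_modelopt_trained_lora` flag enables. The keys below are made up, and the real `postprocess_state_dict` also handles KV-cache keys and quantizer cleanup; this only illustrates the substring-replacement step:

```python
replacements = {
    "base_layer.weight": "weight",
    "base_layer.input_scale": "input_scale",
    "base_layer.weight_scale": "weight_scale",
}

state_dict = {
    "layers.0.q_proj.base_layer.weight": "tensor-0",
    "layers.0.q_proj.base_layer.weight_scale": "tensor-1",
}

post_state_dict = {}
for key, value in state_dict.items():
    for old, new in replacements.items():
        if old in key:
            key = key.replace(old, new)
            break
    post_state_dict[key] = value

# post_state_dict now holds "layers.0.q_proj.weight" and
# "layers.0.q_proj.weight_scale", as a fused deployment expects.
```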
@@ -908,10 +928,10 @@ def postprocess_state_dict(state_dict: dict, maxbound: float, quantization: str
             keys_to_delete.append(key)

     # remove LoRA adapters from state dict
-    for key, value in post_state_dict.items():
-        if "lora" in key and key not in keys_to_delete:
-            keys_to_delete.append(key)
-
+    if is_modelopt_trained_lora:
+        for key, value in post_state_dict.items():
+            if "lora" in key and key not in keys_to_delete:
+                keys_to_delete.append(key)
     # Check for tied weights and remove duplicates
     seen_tensors = {}

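The adapter cleanup is now gated on the same flag, so only modelopt-trained LoRA checkpoints have their adapter tensors dropped from the export. A condensed, self-contained equivalent of the gated loop (function name is mine, not from the patch):

```python
def drop_lora_keys(post_state_dict: dict, is_modelopt_trained_lora: bool) -> dict:
    """Sketch of the gated cleanup: drop LoRA adapter tensors only for
    modelopt-trained LoRA models; other checkpoints keep every key."""
    if not is_modelopt_trained_lora:
        return post_state_dict
    return {k: v for k, v in post_state_dict.items() if "lora" not in k}
```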