@@ -94,27 +94,6 @@ def get_weights_scaling_factor_2(cls, input: torch.Tensor):
94
94
"""Returns per tensor weight scaling factor."""
95
95
return reduce_amax (input ).float () / (6.0 * 448.0 )
96
96
97
- @classmethod
98
- def get_modelopt_weights_scaling_factor (cls , weight_scaling_factor : torch .Tensor , weight_shape ):
99
- """Returns the modelopt weights scaling factor if the quantization is done by trtllm."""
100
- if weight_scaling_factor .dtype == torch .float8_e4m3fn :
101
- return weight_scaling_factor
102
-
103
- if weight_scaling_factor .dtype == torch .uint8 and weight_scaling_factor .ndim == 1 :
104
- # If quantization is done by trtllm, convert cutlass fp4 scale to modelopt fp4 scale
105
- try :
106
- from tensorrt_llm ._torch .auto_deploy .utils .quantization_utils import (
107
- cutlass_fp4_scale_to_modelopt_fp4_scale ,
108
- )
109
-
110
- return cutlass_fp4_scale_to_modelopt_fp4_scale (
111
- weight_scaling_factor , weight_shape [- 2 :]
112
- )
113
- except ImportError as e :
114
- raise ImportError (
115
- "This tensor is quantized by trtllm, but tensorrt_llm cannot be imported."
116
- ) from e
117
-
118
97
@classmethod
119
98
def get_activation_scaling_factor (cls , quantizer ):
120
99
"""Returns the activation scaling factor for export."""
@@ -270,9 +249,20 @@ def _unpack_tensor(input: torch.Tensor):
270
249
return unpacked .reshape (unpacked_shape )
271
250
272
251
# Get scales from kwargs
273
- kwarg ["scale" ] = self .get_modelopt_weights_scaling_factor (
274
- kwarg ["scale" ], self .metadata ["shape" ]
275
- )
252
+ if kwarg ["scale" ].dtype == torch .uint8 and kwarg ["scale" ].ndim == 1 :
253
+ # If quantization is done by trtllm, convert cutlass fp4 scale to modelopt fp4 scale
254
+ try :
255
+ from tensorrt_llm ._torch .auto_deploy .utils .quantization_utils import (
256
+ cutlass_fp4_scale_to_modelopt_fp4_scale ,
257
+ )
258
+
259
+ kwarg ["scale" ] = cutlass_fp4_scale_to_modelopt_fp4_scale (
260
+ kwarg ["scale" ], self .metadata ["shape" ][- 2 :]
261
+ )
262
+ except ImportError as e :
263
+ raise ImportError (
264
+ "This tensor is quantized by trtllm, but tensorrt_llm cannot be imported."
265
+ ) from e
276
266
277
267
if fast :
278
268
from ..triton .fp4_kernel import fp4_dequantize
0 commit comments