-from typing import Optional, Tuple, Union
+from typing import Optional, Union
 
 import numpy as np
 import tensorrt as trt
@@ -92,90 +92,64 @@ def nvfp4_quantize(
         f"lan added nvfp4_quantize entered: {target=} {source_ir=} {name=} {input_tensor.shape=} {input_tensor.dtype=} {block_size=} {amax=} {num_bits=} {exponent_bits=} {scale_num_bits=} {scale_exponent_bits=}"
     )
     with unset_fake_temporarily():
-        if not isinstance(input_tensor, TRTTensor):
-            input_tensor = get_trt_tensor(ctx, input_tensor, name + "_input")
         if input_tensor.dtype not in (
             trt.float32,
             trt.float16,
             trt.bfloat16,
+            torch.float32,
+            torch.float16,
+            torch.bfloat16,
         ):
             raise ValueError(
                 f"dynamic_block_quantize converter received an input of {input_tensor.dtype} type. Supported types: float32 | float16 | bfloat16"
             )
-        # TODO: ADD PADDING IF
-
-        # calculate global scale (the global per-tensor scaling factor, should only contain 1 element)
-        amax = to_torch(
-            amax, None
-        )  # amax is calculated from input_tensor.abs().amax().float()
-        global_scale = torch.divide(amax, 6)
-        global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale")
-
-        if ".weight_quantizer" in name:
-            # calculate block scaling factor of weights
-            [n, k] = input_tensor.shape[-2:]
-            assert block_size != 0, "block_size must be non-zero"
-            assert k % block_size == 0, "k must be a multiple of block_size"
-            reshaped_input_tensor = input_tensor.reshape(
-                tuple(input_tensor.shape[:-2]) + (n, k // block_size, block_size)
+        if len(input_tensor.shape) not in (2, 3):
+            raise ValueError(
+                f"dynamic_block_quantize converter received an input of {input_tensor.shape} shape. Supported shapes: 2D or 3D"
             )
-            per_block_amax = reshaped_input_tensor.abs().amax(dim=-1).float()
-            per_block_scale = torch.divide(per_block_amax, 6)
+        axis = len(input_tensor.shape) - 1
 
-            per_block_scale = get_trt_tensor(
-                ctx, per_block_scale, name + "_per_block_scale"
-            )
+        # TODO: ADD PADDING IF NEEDED
+        # TODO: ADD DYNAMIC SHAPE SUPPORT
 
-            # static double quantization is used for weights
-            quantized_data_in_fp4, quantized_block_scale_in_fp8 = (
-                _static_double_quantize(
-                    ctx,
-                    target,
-                    source_ir,
-                    name,
-                    input_tensor,
-                    per_block_scale,
-                    global_scale,
-                )
-            )
-            output = _block_double_dequantize(
+        global_scale = _calculate_global_scale(ctx, name, amax)
+
+        if ".weight_quantizer" in name:
+            block_scale = _calculate_block_scale(
                 ctx,
-                target,
-                source_ir,
                 name,
-                quantized_data_in_fp4,
-                quantized_block_scale_in_fp8,
-                global_scale,
+                input_tensor,
+                block_size,
             )
-        elif ".input_quantizer" in name:
-            # quantize input tensor to fp4, output should be data tensor in fp4 and block scale tensor in fp8
-            quantized_data_in_fp4, quantized_scale_in_fp8 = _dynamic_quantize(
+            input_tensor = get_trt_tensor(ctx, input_tensor, name + "_input")
+            output = _static_double_quantize(
                 ctx,
                 target,
                 source_ir,
                 name,
                 input_tensor,
+                block_scale,
                 global_scale,
             )
-            # Add double DQ node
-            output = _block_double_dequantize(
+        elif ".input_quantizer" in name:
+            # quantize input tensor to fp4, output should be data tensor in fp4 and block scale tensor in fp8
+            output = _dynamic_double_quantize(
                 ctx,
                 target,
                 source_ir,
                 name,
-                quantized_data_in_fp4,
-                quantized_scale_in_fp8,
+                input_tensor,
                 global_scale,
-                input_tensor.dtype,
             )
+
         else:
             raise ValueError(
-                f"dynamic_block_quantize converter received an input of {name} name. Supported names: weight_quantizer | input_quantizer"
+                f"quantizer received an input of {name}. Supported values: weight_quantizer | input_quantizer"
             )
         return output
 
 
-def _dynamic_quantize(
+def _dynamic_double_quantize(
     ctx: ConversionContext,
     target: Target,
     source_ir: Optional[SourceIR],
@@ -186,7 +160,7 @@ def _dynamic_quantize(
     block_size: int = 16,
     output_type: trt.DataType = trt.DataType.FP4,
     scale_type: trt.DataType = trt.DataType.FP8,
-) -> Tuple[TRTTensor, TRTTensor]:
+) -> TRTTensor:
     """
     quantize input tensor to fp4, output should be data tensor in fp4 and block scale tensor in fp8
     Parameters:
@@ -202,20 +176,13 @@ def _dynamic_quantize(
         The axis to quantize. Default is -1 (the last axis).
     block_size : int
         The block size for quantization. Default is 16.
-    data_qtype : trt.DataType
+    output_type : trt.DataType
         The data type for quantized data. Default is FP4.
-    scale_qtype : trt.DataType
+    scale_type : trt.DataType
         The data type for block scale. Default is FP8.
-    Returns:
-        A tuple of two tensors: quantized data tensor in fp4 and quantized scale tensor in fp8.
+
     """
-    if len(input_tensor.shape) not in (2, 3):
-        raise ValueError(
-            f"dynamic_block_quantize converter received an input of {input_tensor.shape} shape. Supported shapes: 2D or 3D"
-        )
-    if axis < 0:
-        axis = len(input_tensor.shape) + axis
-    # Add DYQ node
+    # dynamic quantize input tensor to fp4
     dynamic_quantize_layer = ctx.net.add_dynamic_quantize(
         input_tensor,
         axis,
@@ -229,51 +196,23 @@ def _dynamic_quantize(
     )
     quantized_data_in_fp4 = dynamic_quantize_layer.get_output(0)
     quantized_scale_in_fp8 = dynamic_quantize_layer.get_output(1)
-    return quantized_data_in_fp4, quantized_scale_in_fp8
 
-
-def _block_double_dequantize(
-    ctx: ConversionContext,
-    target: Target,
-    source_ir: Optional[SourceIR],
-    name: str,
-    input_tensor: TRTTensor,
-    scale: TRTTensor,
-    global_scale: TRTTensor,
-    dtype: trt.DataType = trt.DataType.FLOAT,
-) -> TRTTensor:
-    """
-    dequantize input_tensor from fp4 to dtype (default is float32)
-    Parameters:
-        ctx: ConversionContext,
-        target: Target,
-        source_ir: Optional[SourceIR]
-        name: str
-        input_tensor : Tensor (On GPU)
-            The input tensor.
-        scale : Tensor (On GPU)
-            The block scale tensor.
-        global_scale : Tensor (On GPU)
-            The global per-tensor scaling factor. It should contain only 1 element.
-        dtype : trt.DataType | str
-            The data type for dequantized data. Default is float32.
-    Returns:
-        The dequantized tensor.
-    """
-    # dequantize scale from fp8 to dtype (default is float32)
-    dequantize_scale_layer = ctx.net.add_dequantize(scale, global_scale, dtype)
+    # dequantize scale from fp8 to original dtype (default is float32)
+    dequantize_scale_layer = ctx.net.add_dequantize(
+        quantized_scale_in_fp8, global_scale, input_tensor.dtype
+    )
     set_layer_name(
         dequantize_scale_layer, target, name + "_dequantize_scale", source_ir
     )
     dequantized_scale = dequantize_scale_layer.get_output(0)
 
-    # dequantize input_tensor from fp4 to dtype (default is float32)
+    # dequantize quantized_data_in_fp4 from fp4 to original dtype (default is float32)
     dequantize_data_layer = ctx.net.add_dequantize(
-        input_tensor, dequantized_scale, dtype
+        quantized_data_in_fp4, dequantized_scale, input_tensor.dtype
     )
     set_layer_name(dequantize_data_layer, target, name + "_dequantize_data", source_ir)
-    dq_output = dequantize_data_layer.get_output(0)
-    return dq_output
+    dequantized_data = dequantize_data_layer.get_output(0)
+    return dequantized_data
 
 
 def _static_double_quantize(
@@ -282,9 +221,9 @@ def _static_double_quantize(
     source_ir: Optional[SourceIR],
     name: str,
     input_tensor: TRTTensor,
-    per_block_scale: TRTTensor,
+    block_scale: TRTTensor,
     global_scale: TRTTensor,
-) -> Tuple[TRTTensor, TRTTensor]:
+) -> TRTTensor:
     """
     Parameters:
         ctx: ConversionContext,
@@ -293,41 +232,84 @@ def _static_double_quantize(
         name: str,
         input_tensor : Tensor (On GPU)
             The input tensor.
-        per_block_scale : Tensor (On GPU)
+        block_scale : Tensor (On GPU)
             The per-block scaling factor.
         global_scale : Tensor (On GPU)
             The global per-tensor scaling factor. It should contain only 1 element.
     Returns:
         A tuple of two tensors: quantized data tensor in fp4 and quantized block scaling factor tensor in fp8
     """
-
-    block_scale_quantize_layer = ctx.net.add_quantize(per_block_scale, global_scale)
+    # quantize block scale to fp8
+    block_scale_quantize_layer = ctx.net.add_quantize(block_scale, global_scale)
     set_layer_name(
         block_scale_quantize_layer,
         target,
-        name + "_per_block_scale_quantize",
+        name + "_block_scale_quantize",
         source_ir,
     )
     block_scale_quantize_layer.set_output_type(0, trt.DataType.FP8)
     quantized_block_scale_in_fp8 = block_scale_quantize_layer.get_output(0)
 
+    # dequantize block scale from fp8 to original dtype (default is float32)
     dequantize_block_scale_layer = ctx.net.add_dequantize(
         quantized_block_scale_in_fp8,
         global_scale,
-        per_block_scale.dtype,
+        block_scale.dtype,
     )
     set_layer_name(
         dequantize_block_scale_layer,
         target,
         name + "_dequantize_block_scale",
         source_ir,
     )
-    dequantize_block_scale_layer.precision = trt.DataType.FP8
     dequantized_block_scale = dequantize_block_scale_layer.get_output(0)
 
+    # quantize input tensor to fp4
     data_quantize_layer = ctx.net.add_quantize(input_tensor, dequantized_block_scale)
     set_layer_name(data_quantize_layer, target, name + "_data_quantize", source_ir)
     data_quantize_layer.set_output_type(0, trt.DataType.FP4)
     quantized_data_in_fp4 = data_quantize_layer.get_output(0)
 
-    return quantized_data_in_fp4, quantized_block_scale_in_fp8
+    # dequantize input tensor from fp4 to original dtype (default is float32)
+    dequantize_data_layer = ctx.net.add_dequantize(
+        quantized_data_in_fp4,
+        dequantized_block_scale,
+        input_tensor.dtype,
+    )
+    set_layer_name(dequantize_data_layer, target, name + "_dequantize_data", source_ir)
+    dequantized_data = dequantize_data_layer.get_output(0)
+    return dequantized_data
+
+
+def _calculate_global_scale(
+    ctx: ConversionContext,
+    name: str,
+    amax: TRTTensor,
+) -> TRTTensor:
+    # calculate global scale (the global per-tensor scaling factor, should only contain 1 element)
+    amax = to_torch(
+        amax, None
+    )  # amax is calculated from input_tensor.abs().amax().float()
+    global_scale = torch.divide(amax, 6 * 448)
+    global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale")
+    return global_scale
+
+
+def _calculate_block_scale(
+    ctx: ConversionContext,
+    name: str,
+    input_tensor: TRTTensor,
+    block_size: int,
+) -> TRTTensor:
+
+    [n, k] = input_tensor.shape[-2:]
+    assert block_size != 0, "block_size must be non-zero"
+    assert k % block_size == 0, "k must be a multiple of block_size"
+    reshaped_input_tensor = input_tensor.reshape(
+        tuple(input_tensor.shape[:-2]) + (n, k // block_size, block_size)
+    )
+    block_amax = reshaped_input_tensor.abs().amax(dim=-1).float()
+    block_scale = torch.divide(block_amax, 6)
+
+    block_scale = get_trt_tensor(ctx, block_scale, name + "_block_scale")
+    return block_scale
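
Note on the scale arithmetic in the new helpers: _calculate_global_scale and _calculate_block_scale implement the two levels of NVFP4 double quantization. The change from torch.divide(amax, 6) to torch.divide(amax, 6 * 448) appears to reflect that the per-block scales are themselves stored in FP8: 6 is the largest magnitude representable in FP4 (E2M1) and 448 the largest in FP8 (E4M3), so the global scale must map block scales into the FP8 range rather than only mapping data into the FP4 range. A minimal eager-mode sketch of the same arithmetic, assuming a 2D or 3D torch.Tensor weight (the function name below is hypothetical, for illustration only, not part of the converter):

    import torch

    def reference_nvfp4_scales(weight: torch.Tensor, block_size: int = 16):
        # global per-tensor scale: amax / (fp4_max * fp8_max) = amax / (6 * 448)
        global_scale = weight.abs().amax().float() / (6 * 448)
        # per-block scale over the last axis: block_amax / fp4_max
        n, k = weight.shape[-2:]
        blocks = weight.reshape(*weight.shape[:-2], n, k // block_size, block_size)
        block_scale = blocks.abs().amax(dim=-1).float() / 6
        return global_scale, block_scale

The converter then quantizes the block scales against the global scale to FP8 and the data against the dequantized block scales to FP4, mirroring the quantize/dequantize layer pairs in _static_double_quantize above.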