
Commit 559ada5

restructure the dynamic double quantize and static double quantize code
1 parent 5198f9a commit 559ada5

2 files changed (+94, -111 lines)

Lines changed: 91 additions & 109 deletions
@@ -1,4 +1,4 @@
-from typing import Optional, Tuple, Union
+from typing import Optional, Union

 import numpy as np
 import tensorrt as trt
@@ -92,90 +92,64 @@ def nvfp4_quantize(
         f"lan added nvfp4_quantize entered: {target=} {source_ir=} {name=} {input_tensor.shape=} {input_tensor.dtype=} {block_size=} {amax=} {num_bits=} {exponent_bits=} {scale_num_bits=} {scale_exponent_bits=}"
     )
     with unset_fake_temporarily():
-        if not isinstance(input_tensor, TRTTensor):
-            input_tensor = get_trt_tensor(ctx, input_tensor, name + "_input")
         if input_tensor.dtype not in (
             trt.float32,
             trt.float16,
             trt.bfloat16,
+            torch.float32,
+            torch.float16,
+            torch.bfloat16,
         ):
             raise ValueError(
                 f"dynamic_block_quantize converter received an input of {input_tensor.dtype} type. Supported types: float32 | float16 | bfloat16"
             )
-        # TODO: ADD PADDING IF
-
-        # calculate global scale (the global per-tensor scaling factor, should only contain 1 element)
-        amax = to_torch(
-            amax, None
-        )  # amax is calculated from input_tensor.abs().amax().float()
-        global_scale = torch.divide(amax, 6)
-        global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale")
-
-        if ".weight_quantizer" in name:
-            # calculate block scaling factor of weights
-            [n, k] = input_tensor.shape[-2:]
-            assert block_size != 0, "block_size must be non-zero"
-            assert k % block_size == 0, "k must be a multiple of block_size"
-            reshaped_input_tensor = input_tensor.reshape(
-                tuple(input_tensor.shape[:-2]) + (n, k // block_size, block_size)
+        if len(input_tensor.shape) not in (2, 3):
+            raise ValueError(
+                f"dynamic_block_quantize converter received an input of {input_tensor.shape} shape. Supported shapes: 2D or 3D"
             )
-            per_block_amax = reshaped_input_tensor.abs().amax(dim=-1).float()
-            per_block_scale = torch.divide(per_block_amax, 6)
+        axis = len(input_tensor.shape) - 1

-            per_block_scale = get_trt_tensor(
-                ctx, per_block_scale, name + "_per_block_scale"
-            )
+        # TODO: ADD PADDING IF NEEDED
+        # TODO: ADD DYNAMIC SHAPE SUPPORT

-            # static double quantization is used for weights
-            quantized_data_in_fp4, quantized_block_scale_in_fp8 = (
-                _static_double_quantize(
-                    ctx,
-                    target,
-                    source_ir,
-                    name,
-                    input_tensor,
-                    per_block_scale,
-                    global_scale,
-                )
-            )
-            output = _block_double_dequantize(
+        global_scale = _calculate_global_scale(ctx, name, amax)
+
+        if ".weight_quantizer" in name:
+            block_scale = _calculate_block_scale(
                 ctx,
-                target,
-                source_ir,
                 name,
-                quantized_data_in_fp4,
-                quantized_block_scale_in_fp8,
-                global_scale,
+                input_tensor,
+                block_size,
             )
-        elif ".input_quantizer" in name:
-            # quantize input tensor to fp4, output should be data tensor in fp4 and block scale tensor in fp8
-            quantized_data_in_fp4, quantized_scale_in_fp8 = _dynamic_quantize(
+            input_tensor = get_trt_tensor(ctx, input_tensor, name + "_input")
+            output = _static_double_quantize(
                 ctx,
                 target,
                 source_ir,
                 name,
                 input_tensor,
+                block_scale,
                 global_scale,
             )
-            # Add double DQ node
-            output = _block_double_dequantize(
+        elif ".input_quantizer" in name:
+            # quantize input tensor to fp4, output should be data tensor in fp4 and block scale tensor in fp8
+            output = _dynamic_double_quantize(
                 ctx,
                 target,
                 source_ir,
                 name,
-                quantized_data_in_fp4,
-                quantized_scale_in_fp8,
+                input_tensor,
                 global_scale,
-                input_tensor.dtype,
             )
+
         else:
             raise ValueError(
-                f"dynamic_block_quantize converter received an input of {name} name. Supported names: weight_quantizer | input_quantizer"
+                f"quantizer received an input of {name}. Supported values: weight_quantizer | input_quantizer"
            )
         return output


-def _dynamic_quantize(
+def _dynamic_double_quantize(
     ctx: ConversionContext,
     target: Target,
     source_ir: Optional[SourceIR],
@@ -186,7 +160,7 @@ def _dynamic_quantize(
     block_size: int = 16,
     output_type: trt.DataType = trt.DataType.FP4,
     scale_type: trt.DataType = trt.DataType.FP8,
-) -> Tuple[TRTTensor, TRTTensor]:
+) -> TRTTensor:
     """
     quantize input tensor to fp4, output should be data tensor in fp4 and block scale tensor in fp8
     Parameters:
@@ -202,20 +176,13 @@
             The axis to quantize. Default is -1 (the last axis).
         block_size : int
             The block size for quantization. Default is 16.
-        data_qtype : trt.DataType
+        output_type : trt.DataType
             The data type for quantized data. Default is FP4.
-        scale_qtype : trt.DataType
+        scale_type : trt.DataType
             The data type for block scale. Default is FP8.
-    Returns:
-        A tuple of two tensors: quantized data tensor in fp4 and quantized scale tensor in fp8.
+
     """
-    if len(input_tensor.shape) not in (2, 3):
-        raise ValueError(
-            f"dynamic_block_quantize converter received an input of {input_tensor.shape} shape. Supported shapes: 2D or 3D"
-        )
-    if axis < 0:
-        axis = len(input_tensor.shape) + axis
-    # Add DYQ node
+    # dynamic quantize input tensor to fp4
     dynamic_quantize_layer = ctx.net.add_dynamic_quantize(
         input_tensor,
         axis,
@@ -229,51 +196,23 @@
     )
     quantized_data_in_fp4 = dynamic_quantize_layer.get_output(0)
     quantized_scale_in_fp8 = dynamic_quantize_layer.get_output(1)
-    return quantized_data_in_fp4, quantized_scale_in_fp8

-
-def _block_double_dequantize(
-    ctx: ConversionContext,
-    target: Target,
-    source_ir: Optional[SourceIR],
-    name: str,
-    input_tensor: TRTTensor,
-    scale: TRTTensor,
-    global_scale: TRTTensor,
-    dtype: trt.DataType = trt.DataType.FLOAT,
-) -> TRTTensor:
-    """
-    dequantize input_tensor from fp4 to dtype (default is float32)
-    Parameters:
-        ctx: ConversionContext,
-        target: Target,
-        source_ir: Optional[SourceIR]
-        name: str
-        input_tensor : Tensor (On GPU)
-            The input tensor.
-        scale : Tensor (On GPU)
-            The block scale tensor.
-        global_scale : Tensor (On GPU)
-            The global per-tensor scaling factor. It should contain only 1 element.
-        dtype : trt.DataType | str
-            The data type for dequantized data. Default is float32.
-    Returns:
-        The dequantized tensor.
-    """
-    # dequantize scale from fp8 to dtype (default is float32)
-    dequantize_scale_layer = ctx.net.add_dequantize(scale, global_scale, dtype)
+    # dequantize scale from fp8 to original dtype (default is float32)
+    dequantize_scale_layer = ctx.net.add_dequantize(
+        quantized_scale_in_fp8, global_scale, input_tensor.dtype
+    )
     set_layer_name(
         dequantize_scale_layer, target, name + "_dequantize_scale", source_ir
     )
     dequantized_scale = dequantize_scale_layer.get_output(0)

-    # dequantize input_tensor from fp4 to dtype (default is float32)
+    # dequantize quantized_data_in_fp4 from fp4 to original dtype (default is float32)
     dequantize_data_layer = ctx.net.add_dequantize(
-        input_tensor, dequantized_scale, dtype
+        quantized_data_in_fp4, dequantized_scale, input_tensor.dtype
     )
     set_layer_name(dequantize_data_layer, target, name + "_dequantize_data", source_ir)
-    dq_output = dequantize_data_layer.get_output(0)
-    return dq_output
+    dequantized_data = dequantize_data_layer.get_output(0)
+    return dequantized_data


 def _static_double_quantize(
@@ -282,9 +221,9 @@ def _static_double_quantize(
     source_ir: Optional[SourceIR],
     name: str,
     input_tensor: TRTTensor,
-    per_block_scale: TRTTensor,
+    block_scale: TRTTensor,
     global_scale: TRTTensor,
-) -> Tuple[TRTTensor, TRTTensor]:
+) -> TRTTensor:
     """
     Parameters:
         ctx: ConversionContext,
@@ -293,41 +232,84 @@
         name: str,
         input_tensor : Tensor (On GPU)
             The input tensor.
-        per_block_scale : Tensor (On GPU)
+        block_scale : Tensor (On GPU)
             The per-block scaling factor.
         global_scale : Tensor (On GPU)
             The global per-tensor scaling factor. It should contain only 1 element.
     Returns:
         A tuple of two tensors: quantized data tensor in fp4 and quantized block scaling factor tensor in fp8
     """
-
-    block_scale_quantize_layer = ctx.net.add_quantize(per_block_scale, global_scale)
+    # quantize block scale to fp8
+    block_scale_quantize_layer = ctx.net.add_quantize(block_scale, global_scale)
     set_layer_name(
         block_scale_quantize_layer,
         target,
-        name + "_per_block_scale_quantize",
+        name + "_block_scale_quantize",
         source_ir,
     )
     block_scale_quantize_layer.set_output_type(0, trt.DataType.FP8)
     quantized_block_scale_in_fp8 = block_scale_quantize_layer.get_output(0)

+    # dequantize block scale from fp8 to original dtype (default is float32)
     dequantize_block_scale_layer = ctx.net.add_dequantize(
         quantized_block_scale_in_fp8,
         global_scale,
-        per_block_scale.dtype,
+        block_scale.dtype,
     )
     set_layer_name(
         dequantize_block_scale_layer,
         target,
         name + "_dequantize_block_scale",
         source_ir,
     )
-    dequantize_block_scale_layer.precision = trt.DataType.FP8
     dequantized_block_scale = dequantize_block_scale_layer.get_output(0)

+    # quantize input tensor to fp4
     data_quantize_layer = ctx.net.add_quantize(input_tensor, dequantized_block_scale)
     set_layer_name(data_quantize_layer, target, name + "_data_quantize", source_ir)
     data_quantize_layer.set_output_type(0, trt.DataType.FP4)
     quantized_data_in_fp4 = data_quantize_layer.get_output(0)

-    return quantized_data_in_fp4, quantized_block_scale_in_fp8
+    # dequantize input tensor from fp4 to original dtype (default is float32)
+    dequantize_data_layer = ctx.net.add_dequantize(
+        quantized_data_in_fp4,
+        dequantized_block_scale,
+        input_tensor.dtype,
+    )
+    set_layer_name(dequantize_data_layer, target, name + "_dequantize_data", source_ir)
+    dequantized_data = dequantize_data_layer.get_output(0)
+    return dequantized_data
+
+
+def _calculate_global_scale(
+    ctx: ConversionContext,
+    name: str,
+    amax: TRTTensor,
+) -> TRTTensor:
+    # calculate global scale (the global per-tensor scaling factor, should only contain 1 element)
+    amax = to_torch(
+        amax, None
+    )  # amax is calculated from input_tensor.abs().amax().float()
+    global_scale = torch.divide(amax, 6 * 448)
+    global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale")
+    return global_scale
+
+
+def _calculate_block_scale(
+    ctx: ConversionContext,
+    name: str,
+    input_tensor: TRTTensor,
+    block_size: int,
+) -> TRTTensor:
+
+    [n, k] = input_tensor.shape[-2:]
+    assert block_size != 0, "block_size must be non-zero"
+    assert k % block_size == 0, "k must be a multiple of block_size"
+    reshaped_input_tensor = input_tensor.reshape(
+        tuple(input_tensor.shape[:-2]) + (n, k // block_size, block_size)
+    )
+    block_amax = reshaped_input_tensor.abs().amax(dim=-1).float()
+    block_scale = torch.divide(block_amax, 6)
+
+    block_scale = get_trt_tensor(ctx, block_scale, name + "_block_scale")
+    return block_scale
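
For orientation, here is a minimal standalone sketch (plain PyTorch, no TensorRT) of the two-level scaling scheme that the new _calculate_global_scale and _calculate_block_scale helpers implement. The helper names and constants below are illustrative assumptions, not the converter API: 6.0 is taken as the largest magnitude representable in FP4 (E2M1) and 448.0 as the largest magnitude in FP8 (E4M3), matching the 6 and 6 * 448 divisors in the diff.

import torch

FP4_MAX = 6.0    # assumed E2M1 max magnitude
FP8_MAX = 448.0  # assumed E4M3 max magnitude

def calculate_global_scale(amax: torch.Tensor) -> torch.Tensor:
    # single-element per-tensor scale, mirroring _calculate_global_scale: amax / (6 * 448)
    return amax / (FP4_MAX * FP8_MAX)

def calculate_block_scale(weight: torch.Tensor, block_size: int = 16) -> torch.Tensor:
    # per-block scale over the last dimension, mirroring _calculate_block_scale: block_amax / 6
    n, k = weight.shape[-2:]
    assert block_size != 0 and k % block_size == 0
    blocks = weight.reshape(*weight.shape[:-2], n, k // block_size, block_size)
    return blocks.abs().amax(dim=-1).float() / FP4_MAX

# usage sketch
w = torch.randn(3, 16)
global_scale = calculate_global_scale(w.abs().amax().float())
block_scale = calculate_block_scale(w, block_size=16)
# In the converter, block_scale is quantized to FP8 against global_scale, the weight is
# quantized to FP4 against the dequantized block scale, and both are dequantized back so
# downstream layers keep seeing the original dtype (the static double Q/DQ pattern above).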

tests/py/dynamo/models/test_models_export.py

Lines changed: 3 additions & 2 deletions
@@ -215,7 +215,7 @@ def test_base_fp4(ir):
     class SimpleNetwork(torch.nn.Module):
         def __init__(self):
             super(SimpleNetwork, self).__init__()
-            self.linear1 = torch.nn.Linear(in_features=16, out_features=5)
+            self.linear1 = torch.nn.Linear(in_features=16, out_features=3)

         def forward(self, x):
             x = self.linear1(x)
@@ -249,7 +249,7 @@ def calibrate_loop(model):
     outputs_trt = trt_model(input_tensor)
     print(f"lan added outputs_trt: {outputs_trt}")
     print(f"lan added output_pyt: {output_pyt}")
-    assert torch.allclose(output_pyt, outputs_trt, rtol=5e-1, atol=5e-1)
+    assert torch.allclose(output_pyt, outputs_trt, rtol=4e-1, atol=4e-1)


 @unittest.skipIf(
@@ -284,6 +284,7 @@ def calibrate_loop(model):
     input_tensor = torch.randn(1, 10).cuda()
     model = SimpleNetwork().eval().cuda()
     quant_cfg = mtq.FP8_DEFAULT_CFG
+
     mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
     # model has FP8 qdq nodes at this point
     output_pyt = model(input_tensor)
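
As a small aside on the relaxed tolerance in the test above, a sketch of what the assertion checks; the tensors here are made-up values, not outputs from the test. torch.allclose(a, b, rtol, atol) passes when |a - b| <= atol + rtol * |b| holds elementwise.

import torch

output_pyt = torch.tensor([1.00, -0.50, 0.25])
outputs_trt = torch.tensor([1.30, -0.45, 0.10])  # hypothetical TRT outputs
# with rtol=4e-1 and atol=4e-1 every element satisfies |diff| <= 0.4 + 0.4 * |outputs_trt|
print(torch.allclose(output_pyt, outputs_trt, rtol=4e-1, atol=4e-1))  # True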

0 commit comments