Skip to content

Commit 9b41093

Browse files
committed
Add fix and unittest
Signed-off-by: gcunhase <[email protected]>
1 parent 08fb23f commit 9b41093

File tree

3 files changed

+219
-4
lines changed

3 files changed

+219
-4
lines changed

modelopt/onnx/quantization/graph_utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -526,9 +526,8 @@ def build_non_residual_input_map(
526526
# Input in the longest path to LCA is the non-residual input
527527
lca, d1, d2 = find_lowest_common_ancestor(input1_producer, input2_producer)
528528

529-
# Generally if both the inputs have a backbone then both backbones are of the same type
530529
if backbone1 and backbone2:
531-
if backbone1 == backbone2 or backbone1.op != backbone2.op:
530+
if backbone1 == backbone2:
532531
non_residual_inputs[node.name] = None
533532
continue
534533

tests/_test_utils/onnx_quantization/lib_test_models.py

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,3 +372,185 @@ def build_conv_concat_model():
372372
onnx.checker.check_model(model_inferred)
373373

374374
return model_inferred
375+
376+
377+
def build_convtranspose_conv_residual_model():
    """Build an ONNX model with a ConvTranspose stem followed by a Conv residual block.

    Topology::

        input_0 -> ConvTranspose -> Relu -+-> Conv -> BN -> Relu -> Conv -> BN -+
                                          |                                     v
                                          +----------------------------------> Add -> Relu -> output_0

    All initializers are FLOAT tensors filled with uniform random values drawn
    from [0.5, 1.0); the values are not seeded, so they differ between calls.

    Returns:
        The shape-inferred model (opset 13, IR version 10) after it has passed
        ``onnx.checker.check_model``.
    """
    float_type = onnx.TensorProto.FLOAT

    # Graph-level inputs and outputs
    inputs = [helper.make_tensor_value_info("input_0", float_type, (2, 39, 96, 192))]
    outputs = [helper.make_tensor_value_info("output_0", float_type, (2, 32, 192, 384))]

    # Intermediate tensor names, shared between producers and consumers
    convt_out = "convtranspose1_convtranspose/ConvTranspose:0"
    relu1_out = "relu1_relu/Relu:0"
    conv2_out = "conv2_conv/Conv2D:0"
    bn1_out = "bn1_batchnorm/BatchNormalization:0"
    relu2_out = "relu2_relu/Relu:0"
    conv3_out = "conv3_conv/Conv2D:0"
    bn2_out = "bn2_batchnorm/BatchNormalization:0"
    add_out = "add1_add/Add:0"

    def _conv3x3(node_name, data, weights, result):
        # Bias-free 3x3 convolution that keeps the spatial size (stride 1, pad 1).
        return helper.make_node(
            op_type="Conv",
            inputs=[data, weights],
            outputs=[result],
            name=node_name,
            dilations=[1, 1],
            group=1,
            kernel_shape=[3, 3],
            pads=[1, 1, 1, 1],
            strides=[1, 1],
        )

    def _batchnorm(node_name, data, prefix, result):
        # BatchNormalization whose parameters are named "<prefix>_scale|bias|mean|var".
        return helper.make_node(
            op_type="BatchNormalization",
            inputs=[data, f"{prefix}_scale", f"{prefix}_bias", f"{prefix}_mean", f"{prefix}_var"],
            outputs=[result],
            name=node_name,
        )

    def _relu(node_name, data, result):
        return helper.make_node(op_type="Relu", inputs=[data], outputs=[result], name=node_name)

    # Nodes, listed in topological order
    nodes = [
        # Stem: 2x2 transposed convolution with stride 2, followed by Relu
        helper.make_node(
            op_type="ConvTranspose",
            inputs=["input_0", "weights_1", "bias_1"],
            outputs=[convt_out],
            name="convtranspose1_convtranspose/ConvTranspose",
            dilations=[1, 1],
            group=1,
            kernel_shape=[2, 2],
            pads=[0, 0, 0, 0],
            strides=[2, 2],
        ),
        _relu("relu1_relu/Relu", convt_out, relu1_out),
        # Residual branch: Conv -> BN -> Relu -> Conv -> BN
        _conv3x3("conv2_conv/Conv2D", relu1_out, "weights_2", conv2_out),
        _batchnorm("bn1_batchnorm/BatchNormalization", conv2_out, "bn1", bn1_out),
        _relu("relu2_relu/Relu", bn1_out, relu2_out),
        _conv3x3("conv3_conv/Conv2D", relu2_out, "weights_3", conv3_out),
        _batchnorm("bn2_batchnorm/BatchNormalization", conv3_out, "bn2", bn2_out),
        # Merge the skip connection (stem output) with the residual branch
        helper.make_node(
            op_type="Add",
            inputs=[relu1_out, bn2_out],
            outputs=[add_out],
            name="add1_add/Add",
        ),
        _relu("relu3_relu/Relu", add_out, "output_0"),
    ]

    # Initializer name -> dims, kept in the order the tensors are created
    initializer_dims = [
        ("weights_1", (39, 32, 2, 2)),
        ("bias_1", (32,)),
        ("weights_2", (32, 32, 3, 3)),
        ("bn1_scale", (32,)),
        ("bn1_bias", (32,)),
        ("bn1_mean", (32,)),
        ("bn1_var", (32,)),
        ("weights_3", (32, 32, 3, 3)),
        ("bn2_scale", (32,)),
        ("bn2_bias", (32,)),
        ("bn2_mean", (32,)),
        ("bn2_var", (32,)),
    ]
    initializers = [
        helper.make_tensor(
            name=tensor_name,
            data_type=float_type,
            dims=dims,
            vals=np.random.uniform(low=0.5, high=1.0, size=int(np.prod(dims))),
        )
        for tensor_name, dims in initializer_dims
    ]

    # Assemble the graph and wrap it in a model
    graph = helper.make_graph(
        nodes, "convtranspose_conv_residual", inputs, outputs, initializer=initializers
    )
    model = helper.make_model(graph)
    model.opset_import[0].version = 13
    model.ir_version = 10

    # Infer shapes and validate before handing the model back
    model_inferred = onnx.shape_inference.infer_shapes(model)
    onnx.checker.check_model(model_inferred)

    return model_inferred

tests/unit/onnx/test_quantize_int8.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,14 @@
1919
import onnx_graphsurgeon as gs
2020
import pytest
2121
import torch
22-
from _test_utils.onnx_quantization.lib_test_models import SimpleMLP, export_as_onnx
22+
from _test_utils.onnx_quantization.lib_test_models import (
23+
SimpleMLP,
24+
build_convtranspose_conv_residual_model,
25+
export_as_onnx,
26+
)
2327

2428
import modelopt.onnx.quantization as moq
29+
from modelopt.onnx.utils import save_onnx
2530

2631

2732
def _assert_nodes_are_quantized(nodes):
@@ -52,6 +57,35 @@ def test_int8(tmp_path, high_precision_dtype):
5257
# Load the output model and check QDQ node placements
5358
graph = gs.import_onnx(onnx.load(output_onnx_path))
5459

55-
# Check that all MatMul nodes are quantized
60+
# Check that all MatMul nodes are quantized
5661
mm_nodes = [n for n in graph.nodes if n.op == "MatMul"]
5762
assert _assert_nodes_are_quantized(mm_nodes)
63+
64+
65+
def test_convtranspose_conv_residual_int8(tmp_path):
    """Check INT8 Q/DQ placement on the ConvTranspose + Conv residual model.

    The model is saved into ``tmp_path``, quantized to INT8, and the result is
    inspected: every Conv/ConvTranspose node must be quantized, and each Add
    node must have exactly one input produced by a DequantizeLinear node.

    Args:
        tmp_path: pytest's per-test temporary directory fixture. It must not be
            given a default value: pytest only injects fixtures for parameters
            without defaults, so a default such as ``"./"`` would make the test
            write its ONNX files into the current working directory.
    """
    onnx_model = build_convtranspose_conv_residual_model()
    onnx_path = os.path.join(tmp_path, "convtranspose_conv_residual_model.onnx")
    save_onnx(onnx_model, onnx_path)

    moq.quantize(onnx_path, quantize_mode="int8", high_precision_dtype="fp16")

    # Output model should be produced in the same tmp_path. Only the file
    # extension is swapped, so a ".onnx" elsewhere in the path is left alone.
    output_onnx_path = os.path.splitext(onnx_path)[0] + ".quant.onnx"

    # Check that quantized explicit model is generated
    assert os.path.isfile(output_onnx_path)

    # Load the output model and check QDQ node placements
    graph = gs.import_onnx(onnx.load(output_onnx_path))

    # Check that Conv and ConvTranspose nodes are quantized
    conv_nodes = [n for n in graph.nodes if "Conv" in n.op]
    assert _assert_nodes_are_quantized(conv_nodes)

    # Check that exactly 1 input of each Add node is quantized
    add_nodes = [n for n in graph.nodes if n.op == "Add"]
    for node in add_nodes:
        quantized_inputs = [
            inp
            for inp in node.inputs
            # A tensor without a producer node cannot be a DequantizeLinear
            # output; guard the index so such inputs are skipped, not a crash.
            if len(inp.inputs) > 0 and inp.inputs[0].op == "DequantizeLinear"
        ]
        assert len(quantized_inputs) == 1, (
            f"Expected exactly one quantized input for {node.name}, "
            f"but found {len(quantized_inputs)}!"
        )

0 commit comments

Comments
 (0)