Skip to content

Commit 47e3892

Browse files
committed
Added unittest
Signed-off-by: gcunhase <[email protected]>
1 parent 24f5a73 commit 47e3892

File tree

2 files changed

+162
-6
lines changed

2 files changed

+162
-6
lines changed

tests/_test_utils/onnx_quantization/lib_test_models.py

Lines changed: 125 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -555,3 +555,128 @@ def build_convtranspose_conv_residual_model():
555555
onnx.checker.check_model(model_inferred)
556556

557557
return model_inferred
558+
559+
560+
def build_conv_act_pool_model():
    """Build a Conv -> BatchNormalization -> Relu -> MaxPool -> Conv ONNX model.

    The graph takes a single float input ``input_0`` of shape (32, 64, 256, 256) and
    produces a single float output ``output_0``. All weights, biases and batch-norm
    statistics are initializers filled with uniform random values in [0.5, 1.0).

    Returns:
        The model returned by ``onnx.shape_inference.infer_shapes`` after it has been
        validated with ``onnx.checker.check_model``.
    """
    # Define your model inputs and outputs
    input_names = ["input_0"]
    output_names = ["output_0"]
    input_shapes = [(32, 64, 256, 256)]
    # Spatial size through the graph (no padding anywhere):
    #   Conv 3x3, stride 1    : 256 - 3 + 1          = 254
    #   MaxPool 3x3, stride 2 : (254 - 3) // 2 + 1   = 126
    #   Conv 3x3, stride 1    : 126 - 3 + 1          = 124
    # The declared output shape must match this, otherwise it disagrees with what the
    # nodes actually compute.
    output_shapes = [(32, 128, 124, 124)]

    inputs = [
        helper.make_tensor_value_info(input_name, onnx.TensorProto.FLOAT, input_shape)
        for input_name, input_shape in zip(input_names, input_shapes)
    ]
    outputs = [
        helper.make_tensor_value_info(output_name, onnx.TensorProto.FLOAT, output_shape)
        for output_name, output_shape in zip(output_names, output_shapes)
    ]

    # Create the ONNX graph with the nodes
    nodes = [
        # First convolution: 64 -> 128 channels, with bias
        helper.make_node(
            op_type="Conv",
            inputs=["input_0", "weights_1", "bias_1"],
            outputs=["conv1_conv/Conv2D:0"],
            name="conv1_conv/Conv2D",
            dilations=[1, 1],
            group=1,
            kernel_shape=[3, 3],
            pads=[0, 0, 0, 0],
            strides=[1, 1],
        ),
        helper.make_node(
            op_type="BatchNormalization",
            inputs=["conv1_conv/Conv2D:0", "bn1_scale", "bn1_bias", "bn1_mean", "bn1_var"],
            outputs=["bn1_batchnorm/BatchNormalization:0"],
            name="bn1_batchnorm/BatchNormalization",
        ),
        helper.make_node(
            op_type="Relu",
            inputs=["bn1_batchnorm/BatchNormalization:0"],
            outputs=["relu1_relu/Relu:0"],
            name="relu1_relu/Relu",
        ),
        helper.make_node(
            op_type="MaxPool",
            inputs=["relu1_relu/Relu:0"],
            outputs=["maxpool1_maxpool/MaxPool2D:0"],
            name="maxpool1_maxpool/MaxPool2D",
            ceil_mode=False,
            kernel_shape=[3, 3],
            pads=[0, 0, 0, 0],
            strides=[2, 2],
        ),
        # Second convolution: 128 -> 128 channels, no bias input
        helper.make_node(
            op_type="Conv",
            inputs=["maxpool1_maxpool/MaxPool2D:0", "weights_2"],
            outputs=["output_0"],
            name="conv2_conv/Conv2D",
            dilations=[1, 1],
            group=1,
            kernel_shape=[3, 3],
            pads=[0, 0, 0, 0],
            strides=[1, 1],
        ),
    ]

    # Create the ONNX initializers
    initializers = [
        helper.make_tensor(
            name="weights_1",
            data_type=onnx.TensorProto.FLOAT,
            dims=(128, 64, 3, 3),
            vals=np.random.uniform(low=0.5, high=1.0, size=128 * 64 * 3 * 3),
        ),
        helper.make_tensor(
            name="bias_1",
            data_type=onnx.TensorProto.FLOAT,
            dims=(128,),
            vals=np.random.uniform(low=0.5, high=1.0, size=128),
        ),
        helper.make_tensor(
            name="bn1_scale",
            data_type=onnx.TensorProto.FLOAT,
            dims=(128,),
            vals=np.random.uniform(low=0.5, high=1.0, size=128),
        ),
        helper.make_tensor(
            name="bn1_bias",
            data_type=onnx.TensorProto.FLOAT,
            dims=(128,),
            vals=np.random.uniform(low=0.5, high=1.0, size=128),
        ),
        helper.make_tensor(
            name="bn1_mean",
            data_type=onnx.TensorProto.FLOAT,
            dims=(128,),
            vals=np.random.uniform(low=0.5, high=1.0, size=128),
        ),
        helper.make_tensor(
            name="bn1_var",
            data_type=onnx.TensorProto.FLOAT,
            dims=(128,),
            vals=np.random.uniform(low=0.5, high=1.0, size=128),
        ),
        helper.make_tensor(
            name="weights_2",
            data_type=onnx.TensorProto.FLOAT,
            dims=(128, 128, 3, 3),
            vals=np.random.uniform(low=0.5, high=1.0, size=128 * 128 * 3 * 3),
        ),
    ]

    # Create the ONNX graph with the nodes and initializers
    graph = helper.make_graph(nodes, "conv_act_pool", inputs, outputs, initializer=initializers)

    # Create the ONNX model
    model = helper.make_model(graph)
    model.opset_import[0].version = 13
    model.ir_version = 10

    # Check the ONNX model
    model_inferred = onnx.shape_inference.infer_shapes(model)
    onnx.checker.check_model(model_inferred)

    return model_inferred

tests/unit/onnx/test_quantize_int8.py

Lines changed: 37 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@
2121
import torch
2222
from _test_utils.onnx_quantization.lib_test_models import (
2323
SimpleMLP,
24+
build_conv_act_pool_model,
2425
build_convtranspose_conv_residual_model,
2526
export_as_onnx,
2627
)
@@ -29,13 +30,18 @@
2930
from modelopt.onnx.utils import save_onnx
3031

3132

32-
def _assert_nodes_are_quantized(nodes):
33+
def _assert_nodes_quantization(nodes, should_be_quantized=True):
    """Assert whether the variable inputs of ``nodes`` are produced by DequantizeLinear.

    Only inputs that are ``gs.Variable`` instances are checked; other inputs are skipped.

    Args:
        nodes: Iterable of graph nodes whose inputs are inspected.
        should_be_quantized: If True, every checked input must be produced by a
            ``DequantizeLinear`` node. If False, no checked input may be.

    Returns:
        True when all checks pass, so callers can wrap the call in ``assert``.

    Raises:
        AssertionError: If an input does not match the expected quantization state.
    """
    for node in nodes:
        for inp in node.inputs:
            if not isinstance(inp, gs.Variable):
                continue

            # Look at the first producer directly instead of ``node.i(idx)``: a variable
            # with no producer (e.g. a graph input) would otherwise raise IndexError
            # rather than being reported as "not quantized".
            producers = inp.inputs
            is_quantized = bool(producers) and producers[0].op == "DequantizeLinear"

            if should_be_quantized:
                assert is_quantized, (
                    f"Input '{inp.name}' of node '{node.name}' is not quantized but should be!"
                )
            else:
                assert not is_quantized, (
                    f"Input '{inp.name}' of node '{node.name}' is quantized but should not be!"
                )
    return True
4046

4147

@@ -59,7 +65,7 @@ def test_int8(tmp_path, high_precision_dtype):
5965

6066
# Check that all MatMul nodes are quantized
6167
mm_nodes = [n for n in graph.nodes if n.op == "MatMul"]
62-
assert _assert_nodes_are_quantized(mm_nodes)
68+
assert _assert_nodes_quantization(mm_nodes)
6369

6470

6571
def test_convtranspose_conv_residual_int8(tmp_path):
@@ -80,7 +86,7 @@ def test_convtranspose_conv_residual_int8(tmp_path):
8086

8187
# Check that Conv and ConvTransposed are quantized
8288
conv_nodes = [n for n in graph.nodes if "Conv" in n.op]
83-
assert _assert_nodes_are_quantized(conv_nodes)
89+
assert _assert_nodes_quantization(conv_nodes)
8490

8591
# Check that only 1 input of Add is quantized
8692
add_nodes = [n for n in graph.nodes if n.op == "Add"]
@@ -89,3 +95,28 @@ def test_convtranspose_conv_residual_int8(tmp_path):
8995
assert len(quantized_inputs) == 1, (
9096
f"More than one input of {node.name} is being quantized, but only one should be quantized!"
9197
)
98+
99+
100+
def test_conv_act_pool_int8(tmp_path):
    """Quantize the conv/act/pool model to INT8 and verify Q/DQ placement.

    Conv inputs are expected to be quantized, while MaxPool inputs are expected not to be.
    """
    # Build the model and store it in the temporary directory
    model_path = os.path.join(tmp_path, "conv_act_pool_model.onnx")
    save_onnx(build_conv_act_pool_model(), model_path)

    moq.quantize(model_path, quantize_mode="int8", high_precision_dtype="fp16")

    # The quantized model is expected next to the input model
    quantized_path = model_path.replace(".onnx", ".quant.onnx")
    assert os.path.isfile(quantized_path)

    # Load the quantized model to inspect where Q/DQ nodes were placed
    graph = gs.import_onnx(onnx.load(quantized_path))

    def nodes_of(op_type):
        # Collect every node in the quantized graph with the given op type
        return [node for node in graph.nodes if node.op == op_type]

    # Conv must be quantized
    assert _assert_nodes_quantization(nodes_of("Conv"))

    # MaxPool must not be quantized
    assert _assert_nodes_quantization(nodes_of("MaxPool"), should_be_quantized=False)

0 commit comments

Comments
 (0)