Skip to content

Commit 12ac717

Browse files
committed
Added unittest
Signed-off-by: gcunhase <[email protected]>
1 parent a2b5ebb commit 12ac717

File tree

2 files changed

+162
-6
lines changed

2 files changed

+162
-6
lines changed

tests/_test_utils/onnx_quantization/lib_test_models.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -554,3 +554,128 @@ def build_convtranspose_conv_residual_model():
554554
onnx.checker.check_model(model_inferred)
555555

556556
return model_inferred
557+
558+
559+
def build_conv_act_pool_model():
    """Build a Conv -> BatchNormalization -> Relu -> MaxPool -> Conv ONNX model.

    The first Conv has a bias, the second one does not. All weights and
    BatchNormalization parameters are random values drawn uniformly from
    [0.5, 1.0). The model uses opset 13 and IR version 10.

    Returns:
        The shape-inferred model, after it has passed ``onnx.checker.check_model``.
    """
    # Define your model inputs and outputs
    input_names = ["input_0"]
    output_names = ["output_0"]
    input_shapes = [(32, 64, 256, 256)]
    # Every Conv/MaxPool below is unpadded, so the spatial size shrinks:
    #   Conv    3x3, stride 1: 256 -> 254
    #   MaxPool 3x3, stride 2: 254 -> 126  (floor((254 - 3) / 2) + 1)
    #   Conv    3x3, stride 1: 126 -> 124
    # The declared output shape must match what the graph actually produces.
    output_shapes = [(32, 128, 124, 124)]

    inputs = [
        helper.make_tensor_value_info(input_name, onnx.TensorProto.FLOAT, input_shape)
        for input_name, input_shape in zip(input_names, input_shapes)
    ]
    outputs = [
        helper.make_tensor_value_info(output_name, onnx.TensorProto.FLOAT, output_shape)
        for output_name, output_shape in zip(output_names, output_shapes)
    ]

    # Create the ONNX graph with the nodes
    nodes = [
        helper.make_node(
            op_type="Conv",
            inputs=["input_0", "weights_1", "bias_1"],
            outputs=["conv1_conv/Conv2D:0"],
            name="conv1_conv/Conv2D",
            dilations=[1, 1],
            group=1,
            kernel_shape=[3, 3],
            pads=[0, 0, 0, 0],
            strides=[1, 1],
        ),
        helper.make_node(
            op_type="BatchNormalization",
            inputs=["conv1_conv/Conv2D:0", "bn1_scale", "bn1_bias", "bn1_mean", "bn1_var"],
            outputs=["bn1_batchnorm/BatchNormalization:0"],
            name="bn1_batchnorm/BatchNormalization",
        ),
        helper.make_node(
            op_type="Relu",
            inputs=["bn1_batchnorm/BatchNormalization:0"],
            outputs=["relu1_relu/Relu:0"],
            name="relu1_relu/Relu",
        ),
        helper.make_node(
            op_type="MaxPool",
            inputs=["relu1_relu/Relu:0"],
            outputs=["maxpool1_maxpool/MaxPool2D:0"],
            name="maxpool1_maxpool/MaxPool2D",
            ceil_mode=False,
            kernel_shape=[3, 3],
            pads=[0, 0, 0, 0],
            strides=[2, 2],
        ),
        helper.make_node(
            op_type="Conv",
            inputs=["maxpool1_maxpool/MaxPool2D:0", "weights_2"],
            outputs=["output_0"],
            name="conv2_conv/Conv2D",
            dilations=[1, 1],
            group=1,
            kernel_shape=[3, 3],
            pads=[0, 0, 0, 0],
            strides=[1, 1],
        ),
    ]

    def _random_float_tensor(name, dims):
        # One flat draw per tensor, with as many values as `dims` has elements.
        return helper.make_tensor(
            name=name,
            data_type=onnx.TensorProto.FLOAT,
            dims=dims,
            vals=np.random.uniform(low=0.5, high=1.0, size=int(np.prod(dims))),
        )

    # Create the ONNX initializers
    initializers = [
        _random_float_tensor("weights_1", (128, 64, 3, 3)),
        _random_float_tensor("bias_1", (128,)),
        _random_float_tensor("bn1_scale", (128,)),
        _random_float_tensor("bn1_bias", (128,)),
        _random_float_tensor("bn1_mean", (128,)),
        _random_float_tensor("bn1_var", (128,)),
        _random_float_tensor("weights_2", (128, 128, 3, 3)),
    ]

    # Create the ONNX graph with the nodes and initializers
    graph = helper.make_graph(nodes, "conv_act_pool", inputs, outputs, initializer=initializers)

    # Create the ONNX model
    model = helper.make_model(graph)
    model.opset_import[0].version = 13
    model.ir_version = 10

    # Check the ONNX model
    model_inferred = onnx.shape_inference.infer_shapes(model)
    onnx.checker.check_model(model_inferred)

    return model_inferred

tests/unit/onnx/test_quantize_int8.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import torch
2222
from _test_utils.onnx_quantization.lib_test_models import (
2323
SimpleMLP,
24+
build_conv_act_pool_model,
2425
build_convtranspose_conv_residual_model,
2526
export_as_onnx,
2627
)
@@ -29,13 +30,18 @@
2930
from modelopt.onnx.utils import save_onnx
3031

3132

32-
def _assert_nodes_are_quantized(nodes):
33+
def _assert_nodes_quantization(nodes, should_be_quantized=True):
    """Assert that the variable inputs of ``nodes`` are (or are not) quantized.

    A variable input counts as quantized when the node producing it is a
    ``DequantizeLinear``. Inputs that are not ``gs.Variable`` are ignored.

    Args:
        nodes: Graph nodes whose inputs are inspected.
        should_be_quantized: If True, every variable input must be produced by
            a ``DequantizeLinear`` node; if False, none of them may be.

    Returns:
        True when every check passed, so the call can be wrapped in ``assert``.

    Raises:
        AssertionError: If any variable input violates the expectation.
    """
    for node in nodes:
        for inp in node.inputs:
            if not isinstance(inp, gs.Variable):
                continue
            # A variable without a producer (e.g. a graph input) cannot be the
            # output of a DequantizeLinear node, so treat it as not quantized
            # instead of failing with an IndexError on the producer lookup.
            producer_op = inp.inputs[0].op if inp.inputs else None
            is_quantized = producer_op == "DequantizeLinear"
            if should_be_quantized:
                assert is_quantized, (
                    f"Input '{inp.name}' of node '{node.name}' is not quantized but should be!"
                )
            else:
                assert not is_quantized, (
                    f"Input '{inp.name}' of node '{node.name}' is quantized but should not be!"
                )
    return True
4046

4147

@@ -59,7 +65,7 @@ def test_int8(tmp_path, high_precision_dtype):
5965

6066
# Check that all MatMul nodes are quantized
6167
mm_nodes = [n for n in graph.nodes if n.op == "MatMul"]
62-
assert _assert_nodes_are_quantized(mm_nodes)
68+
assert _assert_nodes_quantization(mm_nodes)
6369

6470

6571
def test_convtranspose_conv_residual_int8(tmp_path):
@@ -80,7 +86,7 @@ def test_convtranspose_conv_residual_int8(tmp_path):
8086

8187
# Check that Conv and ConvTransposed are quantized
8288
conv_nodes = [n for n in graph.nodes if "Conv" in n.op]
83-
assert _assert_nodes_are_quantized(conv_nodes)
89+
assert _assert_nodes_quantization(conv_nodes)
8490

8591
# Check that only 1 input of Add is quantized
8692
add_nodes = [n for n in graph.nodes if n.op == "Add"]
@@ -89,3 +95,28 @@ def test_convtranspose_conv_residual_int8(tmp_path):
8995
assert len(quantized_inputs) == 1, (
9096
f"More than one input of {node.name} is being quantized, but only one should be quantized!"
9197
)
98+
99+
100+
def test_conv_act_pool_int8(tmp_path):
    """INT8-quantize the Conv/activation/pool model and check the Q/DQ placement."""
    # Build the model and store it in the temporary directory
    model = build_conv_act_pool_model()
    model_file = os.path.join(tmp_path, "conv_act_pool_model.onnx")
    save_onnx(model, model_file)

    moq.quantize(model_file, quantize_mode="int8", high_precision_dtype="fp16")

    # The quantized model is expected next to the original one
    quantized_file = model_file.replace(".onnx", ".quant.onnx")
    assert os.path.isfile(quantized_file)

    # Load the quantized model to inspect where the Q/DQ nodes were placed
    quantized_graph = gs.import_onnx(onnx.load(quantized_file))

    # Conv inputs must be quantized, MaxPool inputs must not be
    for op_type, expect_quantized in (("Conv", True), ("MaxPool", False)):
        matching_nodes = [node for node in quantized_graph.nodes if node.op == op_type]
        assert _assert_nodes_quantization(matching_nodes, should_be_quantized=expect_quantized)

0 commit comments

Comments
 (0)