Skip to content

Commit 388e02d

Browse files
committed
Added unittest
Signed-off-by: gcunhase <[email protected]>
1 parent f10e80a commit 388e02d

File tree

2 files changed

+159
-3
lines changed

2 files changed

+159
-3
lines changed

tests/_test_utils/onnx/quantization/lib_test_models.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -673,3 +673,128 @@ def build_conv_batchnorm_sig_mul_model():
673673
onnx.checker.check_model(model_inferred)
674674

675675
return model_inferred
676+
677+
678+
def build_conv_act_pool_model():
679+
# Define your model inputs and outputs
680+
input_names = ["input_0"]
681+
output_names = ["output_0"]
682+
input_shapes = [(32, 64, 256, 256)]
683+
output_shapes = [(32, 128, 128, 128)]
684+
685+
inputs = [
686+
helper.make_tensor_value_info(input_name, onnx.TensorProto.FLOAT, input_shape)
687+
for input_name, input_shape in zip(input_names, input_shapes)
688+
]
689+
outputs = [
690+
helper.make_tensor_value_info(output_name, onnx.TensorProto.FLOAT, output_shape)
691+
for output_name, output_shape in zip(output_names, output_shapes)
692+
]
693+
694+
# Create the ONNX graph with the nodes
695+
nodes = [
696+
helper.make_node(
697+
op_type="Conv",
698+
inputs=["input_0", "weights_1", "bias_1"],
699+
outputs=["conv1_conv/Conv2D:0"],
700+
name="conv1_conv/Conv2D",
701+
dilations=[1, 1],
702+
group=1,
703+
kernel_shape=[3, 3],
704+
pads=[0, 0, 0, 0],
705+
strides=[1, 1],
706+
),
707+
helper.make_node(
708+
op_type="BatchNormalization",
709+
inputs=["conv1_conv/Conv2D:0", "bn1_scale", "bn1_bias", "bn1_mean", "bn1_var"],
710+
outputs=["bn1_batchnorm/BatchNormalization:0"],
711+
name="bn1_batchnorm/BatchNormalization",
712+
),
713+
helper.make_node(
714+
op_type="Relu",
715+
inputs=["bn1_batchnorm/BatchNormalization:0"],
716+
outputs=["relu1_relu/Relu:0"],
717+
name="relu1_relu/Relu",
718+
),
719+
helper.make_node(
720+
op_type="MaxPool",
721+
inputs=["relu1_relu/Relu:0"],
722+
outputs=["maxpool1_maxpool/MaxPool2D:0"],
723+
name="maxpool1_maxpool/MaxPool2D",
724+
ceil_mode=False,
725+
kernel_shape=[3, 3],
726+
pads=[0, 0, 0, 0],
727+
strides=[2, 2],
728+
),
729+
helper.make_node(
730+
op_type="Conv",
731+
inputs=["maxpool1_maxpool/MaxPool2D:0", "weights_2"],
732+
outputs=["output_0"],
733+
name="conv2_conv/Conv2D",
734+
dilations=[1, 1],
735+
group=1,
736+
kernel_shape=[3, 3],
737+
pads=[0, 0, 0, 0],
738+
strides=[1, 1],
739+
),
740+
]
741+
742+
# Create the ONNX initializers
743+
initializers = [
744+
helper.make_tensor(
745+
name="weights_1",
746+
data_type=onnx.TensorProto.FLOAT,
747+
dims=(128, 64, 3, 3),
748+
vals=np.random.uniform(low=0.5, high=1.0, size=128 * 64 * 3 * 3),
749+
),
750+
helper.make_tensor(
751+
name="bias_1",
752+
data_type=onnx.TensorProto.FLOAT,
753+
dims=(128,),
754+
vals=np.random.uniform(low=0.5, high=1.0, size=128),
755+
),
756+
helper.make_tensor(
757+
name="bn1_scale",
758+
data_type=onnx.TensorProto.FLOAT,
759+
dims=(128,),
760+
vals=np.random.uniform(low=0.5, high=1.0, size=128),
761+
),
762+
helper.make_tensor(
763+
name="bn1_bias",
764+
data_type=onnx.TensorProto.FLOAT,
765+
dims=(128,),
766+
vals=np.random.uniform(low=0.5, high=1.0, size=128),
767+
),
768+
helper.make_tensor(
769+
name="bn1_mean",
770+
data_type=onnx.TensorProto.FLOAT,
771+
dims=(128,),
772+
vals=np.random.uniform(low=0.5, high=1.0, size=128),
773+
),
774+
helper.make_tensor(
775+
name="bn1_var",
776+
data_type=onnx.TensorProto.FLOAT,
777+
dims=(128,),
778+
vals=np.random.uniform(low=0.5, high=1.0, size=128),
779+
),
780+
helper.make_tensor(
781+
name="weights_2",
782+
data_type=onnx.TensorProto.FLOAT,
783+
dims=(128, 128, 3, 3),
784+
vals=np.random.uniform(low=0.5, high=1.0, size=128 * 128 * 3 * 3),
785+
),
786+
]
787+
788+
# Create the ONNX graph with the nodes and initializers
789+
graph = helper.make_graph(nodes, "conv_act_pool", inputs, outputs, initializer=initializers)
790+
791+
# Create the ONNX model
792+
model = helper.make_model(graph)
793+
model.opset_import[0].version = 13
794+
model.ir_version = 10
795+
796+
# Check the ONNX model
797+
model_inferred = onnx.shape_inference.infer_shapes(model)
798+
onnx.checker.check_model(model_inferred)
799+
800+
return model_inferred

tests/unit/onnx/test_quantize_int8.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import torch
2222
from _test_utils.onnx.quantization.lib_test_models import (
2323
SimpleMLP,
24+
build_conv_act_pool_model,
2425
build_convtranspose_conv_residual_model,
2526
export_as_onnx,
2627
)
@@ -33,9 +34,14 @@ def assert_nodes_are_quantized(nodes):
3334
for node in nodes:
3435
for inp_idx, inp in enumerate(node.inputs):
3536
if isinstance(inp, gs.Variable):
36-
assert node.i(inp_idx).op == "DequantizeLinear", (
37-
f"Input '{inp.name}' of node '{node.name}' is not quantized but should be!"
38-
)
37+
if should_be_quantized:
38+
assert node.i(inp_idx).op == "DequantizeLinear", (
39+
f"Input '{inp.name}' of node '{node.name}' is not quantized but should be!"
40+
)
41+
else:
42+
assert node.i(inp_idx).op != "DequantizeLinear", (
43+
f"Input '{inp.name}' of node '{node.name}' is quantized but should not be!"
44+
)
3945
return True
4046

4147

@@ -89,3 +95,28 @@ def test_convtranspose_conv_residual_int8(tmp_path):
8995
assert len(quantized_inputs) == 1, (
9096
f"More than one input of {node.name} is being quantized, but only one should be quantized!"
9197
)
98+
99+
100+
def test_conv_act_pool_int8(tmp_path):
101+
onnx_model = build_conv_act_pool_model()
102+
onnx_path = os.path.join(tmp_path, "conv_act_pool_model.onnx")
103+
save_onnx(onnx_model, onnx_path)
104+
105+
moq.quantize(onnx_path, quantize_mode="int8", high_precision_dtype="fp16")
106+
107+
# Output model should be produced in the same tmp_path
108+
output_onnx_path = onnx_path.replace(".onnx", ".quant.onnx")
109+
110+
# Check that quantized explicit model is generated
111+
assert os.path.isfile(output_onnx_path)
112+
113+
# Load the output model and check QDQ node placements
114+
graph = gs.import_onnx(onnx.load(output_onnx_path))
115+
116+
# Check that Conv is quantized
117+
conv_nodes = [n for n in graph.nodes if n.op == "Conv"]
118+
assert _assert_nodes_quantization(conv_nodes)
119+
120+
# Check that MaxPool is not quantized
121+
pool_nodes = [n for n in graph.nodes if n.op == "MaxPool"]
122+
assert _assert_nodes_quantization(pool_nodes, should_be_quantized=False)

0 commit comments

Comments
 (0)