Skip to content

Commit 4b77ae4

Browse files
committed
Add unittest
Signed-off-by: gcunhase <[email protected]>
1 parent d0e83ed commit 4b77ae4

File tree

2 files changed

+142
-0
lines changed

2 files changed

+142
-0
lines changed

tests/_test_utils/onnx_quantization/lib_test_models.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -555,3 +555,115 @@ def build_convtranspose_conv_residual_model():
555555
onnx.checker.check_model(model_inferred)
556556

557557
return model_inferred
558+
559+
560+
def build_conv_batchnorm_sig_mul_model():
    """Build a Conv -> BatchNorm -> Sigmoid -> Mul -> Add -> Relu ONNX model.

    The Sigmoid output is multiplied with the BatchNorm output (a Swish/SiLU
    activation), and the Add forms a residual connection back to the graph
    input. Weights and BatchNorm parameters are random uniform in [0.5, 1.0).

    Returns the shape-inferred ModelProto, validated by the ONNX checker.
    """
    # Single 4-D float tensor in and out, same shape (residual-friendly).
    tensor_shape = (6, 48, 64, 176)
    graph_inputs = [
        helper.make_tensor_value_info("input_0", onnx.TensorProto.FLOAT, tensor_shape)
    ]
    graph_outputs = [
        helper.make_tensor_value_info("output_0", onnx.TensorProto.FLOAT, tensor_shape)
    ]

    # Node chain: Conv -> BN -> Sigmoid -> Mul(sig, bn) -> Add(input, mul) -> Relu.
    conv_node = helper.make_node(
        op_type="Conv",
        inputs=["input_0", "weights_1"],
        outputs=["conv1_conv/Conv2D:0"],
        name="conv1_conv/Conv2D",
        dilations=[1, 1],
        group=1,
        kernel_shape=[3, 3],
        pads=[1, 1, 1, 1],
        strides=[1, 1],
    )
    bn_node = helper.make_node(
        op_type="BatchNormalization",
        inputs=["conv1_conv/Conv2D:0", "bn1_scale", "bn1_bias", "bn1_mean", "bn1_var"],
        outputs=["bn1_batchnorm/BatchNormalization:0"],
        name="bn1_batchnorm/BatchNormalization",
    )
    sigmoid_node = helper.make_node(
        op_type="Sigmoid",
        inputs=["bn1_batchnorm/BatchNormalization:0"],
        outputs=["sig1_sigmoid/Sigmoid:0"],
        name="sig1_sigmoid/Sigmoid",
    )
    mul_node = helper.make_node(
        op_type="Mul",
        inputs=["sig1_sigmoid/Sigmoid:0", "bn1_batchnorm/BatchNormalization:0"],
        outputs=["mul1_mul/Mul:0"],
        name="mul1_mul/Mul",
    )
    add_node = helper.make_node(
        op_type="Add",
        inputs=["input_0", "mul1_mul/Mul:0"],
        outputs=["add1_add/Add:0"],
        name="add1_add/Add",
    )
    relu_node = helper.make_node(
        op_type="Relu",
        inputs=["add1_add/Add:0"],
        outputs=["output_0"],
        name="relu2_relu/Relu",
    )
    nodes = [conv_node, bn_node, sigmoid_node, mul_node, add_node, relu_node]

    # Initializers: one conv weight tensor plus the four per-channel BN params.
    initializers = [
        helper.make_tensor(
            name="weights_1",
            data_type=onnx.TensorProto.FLOAT,
            dims=(48, 48, 3, 3),
            vals=np.random.uniform(low=0.5, high=1.0, size=48 * 48 * 3 * 3),
        )
    ]
    for bn_param in ("bn1_scale", "bn1_bias", "bn1_mean", "bn1_var"):
        initializers.append(
            helper.make_tensor(
                name=bn_param,
                data_type=onnx.TensorProto.FLOAT,
                dims=(48,),
                vals=np.random.uniform(low=0.5, high=1.0, size=48),
            )
        )

    # Assemble the graph and model, pin opset/IR versions used by the tests.
    graph = helper.make_graph(
        nodes, "conv_batchnorm_sig_mul", graph_inputs, graph_outputs, initializer=initializers
    )
    model = helper.make_model(graph)
    model.opset_import[0].version = 13
    model.ir_version = 10

    # Shape-infer and validate before handing the model to callers.
    model_inferred = onnx.shape_inference.infer_shapes(model)
    onnx.checker.check_model(model_inferred)

    return model_inferred

tests/unit/onnx/test_quantize_int8.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import torch
2222
from _test_utils.onnx_quantization.lib_test_models import (
2323
SimpleMLP,
24+
build_conv_batchnorm_sig_mul_model,
2425
build_convtranspose_conv_residual_model,
2526
export_as_onnx,
2627
)
@@ -89,3 +90,32 @@ def test_convtranspose_conv_residual_int8(tmp_path):
8990
assert len(quantized_inputs) == 1, (
9091
f"More than one input of {node.name} is being quantized, but only one should be quantized!"
9192
)
93+
94+
95+
def test_conv_batchnorm_sig_mul_int8(tmp_path="./"):
    """INT8-quantize the Conv+BN+Swish residual model and check QDQ placement.

    Verifies that the quantized model is written next to the input model, that
    the Conv node is quantized, and that exactly one input of the residual Add
    carries a Q/DQ pair.
    """
    onnx_model = build_conv_batchnorm_sig_mul_model()
    onnx_path = os.path.join(tmp_path, "conv_batchnorm_sig_mul_model.onnx")
    save_onnx(onnx_model, onnx_path)

    moq.quantize(onnx_path, quantize_mode="int8", high_precision_dtype="fp16")

    # Output model should be produced in the same tmp_path
    output_onnx_path = onnx_path.replace(".onnx", ".quant.onnx")

    # Check that quantized explicit model is generated
    assert os.path.isfile(output_onnx_path)

    # Load the output model and check QDQ node placements
    graph = gs.import_onnx(onnx.load(output_onnx_path))

    # Check that Conv nodes are quantized (this model has no ConvTranspose)
    conv_nodes = [n for n in graph.nodes if "Conv" in n.op]
    assert _assert_nodes_are_quantized(conv_nodes)

    # Check that exactly one input of each Add is quantized. Guard against
    # producer-less tensors: the residual Add consumes the graph input
    # "input_0" directly, and a graph-input Variable has no producer node,
    # so `inp.inputs[0]` would raise IndexError without the emptiness check.
    add_nodes = [n for n in graph.nodes if n.op == "Add"]
    for node in add_nodes:
        quantized_inputs = [
            inp for inp in node.inputs if inp.inputs and inp.inputs[0].op == "DequantizeLinear"
        ]
        assert len(quantized_inputs) == 1, (
            f"Expected exactly 1 quantized input of {node.name}, got {len(quantized_inputs)}!"
        )

0 commit comments

Comments
 (0)