|
19 | 19 | import onnx |
20 | 20 | import onnx_graphsurgeon as gs |
21 | 21 | from _test_utils.onnx_quantization.lib_test_models import ( |
| 22 | + build_conv_batchnorm_sig_mul_model, |
22 | 23 | build_r1a_model, |
23 | 24 | build_resnet_block, |
24 | 25 | build_resnet_block_with_downsample, |
25 | 26 | export_as_onnx, |
26 | 27 | ) |
27 | 28 |
|
28 | 29 | from modelopt.onnx.quantization.quantize import quantize |
| 30 | +from modelopt.onnx.utils import save_onnx |
29 | 31 |
|
30 | 32 |
|
31 | 33 | def _assert_nodes_are_quantized(nodes): |
@@ -119,3 +121,32 @@ def test_resnet_residual_connection_with_downsample(tmp_path): |
119 | 121 | onnx_path = os.path.join(tmp_path, "model.onnx") |
120 | 122 | export_as_onnx(model_torch, input_tensor, onnx_filename=onnx_path) |
121 | 123 | _check_resnet_residual_connection(onnx_path) |
def test_conv_batchnorm_sig_mul_int8(tmp_path):
    """INT8-quantize a Conv+BatchNorm+Sigmoid+Mul model and verify QDQ placement.

    Exports the model to ONNX, runs explicit INT8 quantization (with FP16 as the
    high-precision dtype), then checks that:
      * the quantized `.quant.onnx` artifact is written next to the input model,
      * all Conv-family nodes carry Q/DQ pairs,
      * each Add node has exactly one quantized (DequantizeLinear-fed) input.
    """
    onnx_model = build_conv_batchnorm_sig_mul_model()
    onnx_path = os.path.join(tmp_path, "conv_batchnorm_sig_mul_model.onnx")
    save_onnx(onnx_model, onnx_path)

    quantize(onnx_path, quantize_mode="int8", high_precision_dtype="fp16")

    # Output model should be produced in the same tmp_path
    output_onnx_path = onnx_path.replace(".onnx", ".quant.onnx")

    # Check that quantized explicit model is generated
    assert os.path.isfile(output_onnx_path)

    # Load the output model and check QDQ node placements
    graph = gs.import_onnx(onnx.load(output_onnx_path))

    # Check that Conv and ConvTranspose nodes are quantized
    conv_nodes = [n for n in graph.nodes if "Conv" in n.op]
    assert _assert_nodes_are_quantized(conv_nodes)

    # Check that exactly one input of each Add is quantized.
    # Guard with `inp.inputs and ...`: in onnx-graphsurgeon, Tensor.inputs is the
    # list of producer nodes, which is empty for graph inputs and constants —
    # indexing [0] unconditionally would raise IndexError for those tensors.
    add_nodes = [n for n in graph.nodes if n.op == "Add"]
    for node in add_nodes:
        quantized_inputs = [
            inp for inp in node.inputs if inp.inputs and inp.inputs[0].op == "DequantizeLinear"
        ]
        assert len(quantized_inputs) == 1, (
            f"Expected exactly one quantized input for {node.name}, "
            f"but found {len(quantized_inputs)}!"
        )
0 commit comments