|
18 | 18 | import numpy as np |
19 | 19 | import onnx |
20 | 20 | import onnx_graphsurgeon as gs |
| 21 | +import pytest |
21 | 22 | from _test_utils.onnx_quantization.lib_test_models import ( |
| 23 | + build_conv_act_pool_model, |
22 | 24 | build_r1a_model, |
23 | 25 | build_resnet_block, |
24 | 26 | build_resnet_block_with_downsample, |
25 | 27 | export_as_onnx, |
26 | 28 | ) |
27 | 29 |
|
28 | 30 | from modelopt.onnx.quantization.quantize import quantize |
| 31 | +from modelopt.onnx.utils import save_onnx |
29 | 32 |
|
30 | 33 |
|
31 | 34 | def _assert_nodes_are_quantized(nodes): |
@@ -119,3 +122,29 @@ def test_resnet_residual_connection_with_downsample(tmp_path): |
119 | 122 | onnx_path = os.path.join(tmp_path, "model.onnx") |
120 | 123 | export_as_onnx(model_torch, input_tensor, onnx_filename=onnx_path) |
121 | 124 | _check_resnet_residual_connection(onnx_path) |
| 125 | + |
| 126 | + |
| 127 | +@pytest.mark.parametrize("include_reshape_node", [False, True]) |
| 128 | +def test_conv_act_pool_int8(tmp_path, include_reshape_node): |
| 129 | + onnx_model = build_conv_act_pool_model(include_reshape_node) |
| 130 | + onnx_path = os.path.join(tmp_path, f"conv_act_pool_model_{include_reshape_node}.onnx") |
| 131 | + save_onnx(onnx_model, onnx_path) |
| 132 | + |
| 133 | + quantize.quantize(onnx_path, quantize_mode="int8", high_precision_dtype="fp16") |
| 134 | + |
| 135 | + # Output model should be produced in the same tmp_path |
| 136 | + output_onnx_path = onnx_path.replace(".onnx", ".quant.onnx") |
| 137 | + |
| 138 | + # Check that quantized explicit model is generated |
| 139 | + assert os.path.isfile(output_onnx_path) |
| 140 | + |
| 141 | + # Load the output model and check QDQ node placements |
| 142 | + graph = gs.import_onnx(onnx.load(output_onnx_path)) |
| 143 | + |
| 144 | + # Check that Conv is quantized |
| 145 | + conv_nodes = [n for n in graph.nodes if n.op == "Conv"] |
| 146 | + assert _assert_nodes_are_quantized(conv_nodes) |
| 147 | + |
| 148 | + # Check that MaxPool is not quantized |
| 149 | + pool_nodes = [n for n in graph.nodes if n.op == "MaxPool"] |
| 150 | + assert _assert_nodes_are_not_quantized(pool_nodes) |
0 commit comments