@@ -19,14 +19,9 @@
 import onnx_graphsurgeon as gs
 import pytest
 import torch
-from _test_utils.onnx.quantization.lib_test_models import (
-    SimpleMLP,
-    build_convtranspose_conv_residual_model,
-    export_as_onnx,
-)
+from _test_utils.onnx.quantization.lib_test_models import SimpleMLP, export_as_onnx
 
 import modelopt.onnx.quantization as moq
-from modelopt.onnx.utils import save_onnx
 
 
 def assert_nodes_are_quantized(nodes):
@@ -60,32 +55,3 @@ def test_int8(tmp_path, high_precision_dtype): |
     # Check that all MatMul nodes are quantized
     mm_nodes = [n for n in graph.nodes if n.op == "MatMul"]
     assert assert_nodes_are_quantized(mm_nodes)
-
-
-def test_convtranspose_conv_residual_int8(tmp_path):
-    onnx_model = build_convtranspose_conv_residual_model()
-    onnx_path = os.path.join(tmp_path, "convtranspose_conv_residual_model.onnx")
-    save_onnx(onnx_model, onnx_path)
-
-    moq.quantize(onnx_path, quantize_mode="int8", high_precision_dtype="fp16")
-
-    # Output model should be produced in the same tmp_path
-    output_onnx_path = onnx_path.replace(".onnx", ".quant.onnx")
-
-    # Check that the quantized explicit model is generated
-    assert os.path.isfile(output_onnx_path)
-
-    # Load the output model and check QDQ node placements
-    graph = gs.import_onnx(onnx.load(output_onnx_path))
-
-    # Check that Conv and ConvTranspose nodes are quantized
-    conv_nodes = [n for n in graph.nodes if "Conv" in n.op]
-    assert assert_nodes_are_quantized(conv_nodes)
-
-    # Check that exactly one input of each Add is quantized
-    add_nodes = [n for n in graph.nodes if n.op == "Add"]
-    for node in add_nodes:
-        quantized_inputs = [inp for inp in node.inputs if inp.inputs[0].op == "DequantizeLinear"]
-        assert len(quantized_inputs) == 1, (
-            f"Exactly one input of {node.name} should be quantized, but found {len(quantized_inputs)}!"
-        )
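For context on what this diff deletes: `build_convtranspose_conv_residual_model` (removed from the imports above, and not shown in this hunk) built the graph the deleted test exercised. Below is a minimal sketch of such a builder, assuming `onnx.helper` and an arbitrary 4-channel NCHW input; the shapes, names, and opset are illustrative guesses, not the actual helper from `lib_test_models`. It wires a ConvTranspose into a Conv and joins both paths with a residual Add, the pattern the test used to verify that exactly one Add input ends up behind a DequantizeLinear node.

```python
import numpy as np
import onnx
from onnx import TensorProto, helper, numpy_helper


def build_convtranspose_conv_residual_model() -> onnx.ModelProto:
    # Hypothetical reconstruction; the real helper may differ.
    # 4-channel NCHW input; ConvTranspose upsamples 8x8 -> 16x16 with stride 2.
    inp = helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 4, 8, 8])
    out = helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 4, 16, 16])

    # Weights live as initializers so the quantizer can insert weight QDQ pairs.
    w_ct = numpy_helper.from_array(
        np.random.randn(4, 4, 2, 2).astype(np.float32), "convtranspose_w"
    )
    w_conv = numpy_helper.from_array(
        np.random.randn(4, 4, 3, 3).astype(np.float32), "conv_w"
    )

    convtranspose = helper.make_node(
        "ConvTranspose",
        ["input", "convtranspose_w"],
        ["ct_out"],
        strides=[2, 2],
        name="convtranspose_1",
    )
    # 3x3 Conv with padding 1 keeps the 16x16 spatial shape.
    conv = helper.make_node(
        "Conv", ["ct_out", "conv_w"], ["conv_out"], pads=[1, 1, 1, 1], name="conv_1"
    )
    # Residual join: ct_out reaches Add both directly and through the Conv, so a
    # correct quantizer should place QDQ on only one of the two Add inputs.
    add = helper.make_node("Add", ["ct_out", "conv_out"], ["output"], name="add_1")

    graph = helper.make_graph(
        [convtranspose, conv, add],
        "convtranspose_conv_residual",
        [inp],
        [out],
        initializer=[w_ct, w_conv],
    )
    return helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
```

The deleted test ran `moq.quantize` on this graph in int8 mode and asserted that both Conv-family nodes received QDQ pairs while the Add kept one unquantized input, i.e. the residual skip path was left in higher precision.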