11 changes: 8 additions & 3 deletions modelopt/onnx/op_types.py
@@ -96,9 +96,9 @@ def is_fusible_scaling_op(op_type: str):
]


def is_copy_op(op_type: str):
"""Returns whether the given op is a copy operator or not."""
return op_type in [
def get_copy_ops():
"""Returns list of copy operators."""
return [
"Flatten",
"Transpose",
"Concat",
@@ -118,6 +118,11 @@ def is_copy_op(op_type: str):
]


def is_copy_op(op_type: str):
"""Returns whether the given op is a copy operator or not."""
return op_type in get_copy_ops()


def is_linear_op(op_type: str):
"""Returns whether the given op type is of Linear category or not."""
return op_type in ["Conv", "ConvTranspose", "Gemm", "MatMul"]
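The refactor above splits the old membership check in two: get_copy_ops() exposes the operator list itself for reuse (notably as the wildcard types in the pattern matching below), while is_copy_op() keeps its original contract. A minimal sketch of the resulting API, using only the two functions shown in this diff:

from modelopt.onnx.op_types import get_copy_ops, is_copy_op

assert is_copy_op("Transpose")   # membership check, behavior unchanged
wild_cards = get_copy_ops()      # reusable list, e.g. for has_path_type's wild_card_types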
9 changes: 4 additions & 5 deletions modelopt/onnx/quantization/graph_utils.py
@@ -29,7 +29,7 @@
from onnxruntime.quantization.calibrate import CalibrationDataReader

from modelopt.onnx.logging_config import logger
from modelopt.onnx.op_types import is_copy_op, is_linear_op
from modelopt.onnx.op_types import get_copy_ops, is_copy_op, is_linear_op
from modelopt.onnx.quantization.ort_utils import create_inference_session
from modelopt.onnx.utils import (
find_lowest_common_ancestor,
@@ -173,7 +173,7 @@ def has_path_type(
def get_fusible_backbone(node: Node, graph: Graph) -> Node | None:
"""Returns the linear backbone node for a given node if it matches the pattern.

TensorRT fuses convolution with BN, Relu etc. when in some specific pattern.
TensorRT fuses convolution with BN, Relu, MaxPool etc. when in some specific pattern.
This rule tries to match some of those patterns.
Note. BiasAdd and ConstMul are optional in path types.

@@ -190,7 +190,7 @@ def _get_backbone(root: Node):
return root

for tensor in root.inputs:
if not isinstance(tensor, Constant):
if not isinstance(tensor, Constant) and tensor.inputs:
parent_node = tensor.inputs[0]
bb = _get_backbone(parent_node)
if bb:
@@ -207,7 +207,7 @@ def _get_backbone(root: Node):
["Mul", "Sigmoid", "BatchNormalization", conv_type],
]
for idx, path_type in enumerate(fusible_linear_path_types):
if has_path_type(node, graph, path_type, is_forward=False, wild_card_types=[]):
if has_path_type(node, graph, path_type, is_forward=False, wild_card_types=get_copy_ops()):
return _get_backbone(node)

return None
@@ -1002,7 +1002,6 @@ def find_nodes_from_matmul_to_exclude(
logger.debug("No MatMul nodes found in the model")
return []

nodes_to_exclude = []
logger.debug(f"Found {len(matmul_nodes)} MatMul nodes to analyze")

if calibration_shapes:
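Taken together, the graph_utils.py changes do two things: the new `and tensor.inputs` guard keeps the backbone walk from indexing into a producer-less tensor (a graph input or initializer), and passing get_copy_ops() as wild_card_types lets the backward pattern match skip over copy ops. A hypothetical illustration of the latter (node and path names invented here; the real patterns live in fusible_linear_path_types):

# Graph: Conv -> BatchNormalization -> Relu -> Reshape -> MaxPool
# Reshape is a copy op; with copy ops passed as wildcards, the backward
# match below succeeds even though Reshape sits between Relu and MaxPool.
path_type = ["MaxPool", "Relu", "BatchNormalization", "Conv"]
matches = has_path_type(
    maxpool_node, graph, path_type, is_forward=False, wild_card_types=get_copy_ops()
)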
13 changes: 9 additions & 4 deletions modelopt/onnx/quantization/partitioning.py
@@ -44,9 +44,10 @@ def _build_fusible_partition(
"""Traverses the graph starting from cur_node and updates the fusible_partition list.

Add a node to the partition if any of these holds:
1. The node is a unary or binary pointwise operation and fusible by cask
2. The node is BN and/or Relu and fusible with preceding Conv op
3. The node is a residual Add and fusible with current partition
1. The node is a unary or binary pointwise operation or a copy op and fusible by cask
2. The node is BN and/or Relu and fusible with preceding Conv op (Conv-Act fusion)
3. The node is MaxPool following a Conv-Act pattern (Conv-Act-Pool fusion)
4. The node is a residual Add and fusible with current partition

Args:
cur_node: Current candidate node for the partition.
@@ -131,11 +132,15 @@ def _is_fusible_mul(mul_node: Node) -> bool:

if (
(
is_copy_op(consumer_node.op)
and _is_cask_fusible(consumer_node, partition_node_outputs)
)
or (
is_pointwise_or_elementwise_op(consumer_node.op)
and _is_cask_fusible(consumer_node, partition_node_outputs)
)
or (
consumer_node.op in ["BatchNormalization", "Relu"]
consumer_node.op in ["BatchNormalization", "Relu", "MaxPool"]
and get_fusible_backbone(consumer_node, graph)
)
or _is_on_non_residual_path(consumer_node)
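Condensed, the updated consumer test in partitioning.py now reads roughly as follows (a paraphrase of the diff above, not the verbatim source): the first clause is new and admits copy ops, and MaxPool joins BatchNormalization/Relu in the backbone check.

# Paraphrased fusibility condition for a candidate consumer node.
fusible = (
    (is_copy_op(consumer_node.op) or is_pointwise_or_elementwise_op(consumer_node.op))
    and _is_cask_fusible(consumer_node, partition_node_outputs)
) or (
    consumer_node.op in ["BatchNormalization", "Relu", "MaxPool"]
    and get_fusible_backbone(consumer_node, graph) is not None
) or _is_on_non_residual_path(consumer_node)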
4 changes: 1 addition & 3 deletions modelopt/onnx/utils.py
@@ -165,9 +165,7 @@ def get_dynamic_graph_inputs(onnx_model: onnx.ModelProto):
List of dynamic inputs.
"""
graph = gs.import_onnx(onnx_model)
return [
inp for inp in graph.inputs if -1 in inp.shape or any(isinstance(s, str) for s in inp.shape)
]
return [inp for inp in graph.inputs if any(isinstance(s, str) or s <= 0 for s in inp.shape)]


def _get_all_shapes(container: Any) -> dict[str, list[int]]:
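The rewritten predicate treats any symbolic (string) dimension and any non-positive dimension as dynamic, instead of special-casing only -1; shapes like ("N", 3, 224, 224) and (0, 3, 224, 224) are both caught now. A small self-contained check built on the same expression:

import onnx
import onnx_graphsurgeon as gs

def has_dynamic_inputs(onnx_model: onnx.ModelProto) -> bool:
    # Mirrors the updated get_dynamic_graph_inputs() filter: "N" (symbolic),
    # 0, and -1 all count as dynamic; concrete positive dims do not.
    graph = gs.import_onnx(onnx_model)
    return any(
        any(isinstance(s, str) or s <= 0 for s in inp.shape) for inp in graph.inputs
    )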
149 changes: 149 additions & 0 deletions tests/_test_utils/onnx/quantization/lib_test_models.py
@@ -673,3 +673,152 @@ def build_conv_batchnorm_sig_mul_model():
onnx.checker.check_model(model_inferred)

return model_inferred


def build_conv_act_pool_model(include_reshape_node=False):
# Define your model inputs and outputs
input_names = ["input_0"]
output_names = ["output_0"]
input_shapes = [(32, 64, 256, 256)]
output_shapes = [(32, 128, 128, 128)]

inputs = [
helper.make_tensor_value_info(input_name, onnx.TensorProto.FLOAT, input_shape)
for input_name, input_shape in zip(input_names, input_shapes)
]
outputs = [
helper.make_tensor_value_info(output_name, onnx.TensorProto.FLOAT, output_shape)
for output_name, output_shape in zip(output_names, output_shapes)
]

# Create the ONNX graph with the nodes
nodes = [
helper.make_node(
op_type="Conv",
inputs=["input_0", "weights_1", "bias_1"],
outputs=["conv1_conv/Conv2D:0"],
name="conv1_conv/Conv2D",
dilations=[1, 1],
group=1,
kernel_shape=[3, 3],
pads=[1, 1, 1, 1],
strides=[1, 1],
),
helper.make_node(
op_type="BatchNormalization",
inputs=["conv1_conv/Conv2D:0", "bn1_scale", "bn1_bias", "bn1_mean", "bn1_var"],
outputs=["bn1_batchnorm/BatchNormalization:0"],
name="bn1_batchnorm/BatchNormalization",
),
helper.make_node(
op_type="Relu",
inputs=["bn1_batchnorm/BatchNormalization:0"],
outputs=["relu1_relu/Relu:0"],
name="relu1_relu/Relu",
),
]
if include_reshape_node:
nodes.append(
helper.make_node(
op_type="Reshape",
inputs=["relu1_relu/Relu:0", "shape_1"],
outputs=["reshape1_reshape/Reshape:0"],
name="reshape1_reshape/Reshape",
),
)
nodes.extend(
[
helper.make_node(
op_type="MaxPool",
inputs=[
"reshape1_reshape/Reshape:0" if include_reshape_node else "relu1_relu/Relu:0"
],
outputs=["maxpool1_maxpool/MaxPool2D:0"],
name="maxpool1_maxpool/MaxPool2D",
ceil_mode=False,
kernel_shape=[3, 3],
pads=[1, 1, 1, 1],
strides=[2, 2],
),
helper.make_node(
op_type="Conv",
inputs=["maxpool1_maxpool/MaxPool2D:0", "weights_2"],
outputs=["output_0"],
name="conv2_conv/Conv2D",
dilations=[1, 1],
group=1,
kernel_shape=[3, 3],
pads=[1, 1, 1, 1],
strides=[1, 1],
),
]
)

# Create the ONNX initializers
initializers = [
helper.make_tensor(
name="weights_1",
data_type=onnx.TensorProto.FLOAT,
dims=(128, 64, 3, 3),
vals=np.random.uniform(low=0.5, high=1.0, size=128 * 64 * 3 * 3),
),
helper.make_tensor(
name="bias_1",
data_type=onnx.TensorProto.FLOAT,
dims=(128,),
vals=np.random.uniform(low=0.5, high=1.0, size=128),
),
helper.make_tensor(
name="bn1_scale",
data_type=onnx.TensorProto.FLOAT,
dims=(128,),
vals=np.random.uniform(low=0.5, high=1.0, size=128),
),
helper.make_tensor(
name="bn1_bias",
data_type=onnx.TensorProto.FLOAT,
dims=(128,),
vals=np.random.uniform(low=0.5, high=1.0, size=128),
),
helper.make_tensor(
name="bn1_mean",
data_type=onnx.TensorProto.FLOAT,
dims=(128,),
vals=np.random.uniform(low=0.5, high=1.0, size=128),
),
helper.make_tensor(
name="bn1_var",
data_type=onnx.TensorProto.FLOAT,
dims=(128,),
vals=np.random.uniform(low=0.5, high=1.0, size=128),
),
helper.make_tensor(
name="weights_2",
data_type=onnx.TensorProto.FLOAT,
dims=(128, 128, 3, 3),
vals=np.random.uniform(low=0.5, high=1.0, size=128 * 128 * 3 * 3),
),
]
if include_reshape_node:
initializers.append(
helper.make_tensor(
name="shape_1",
data_type=onnx.TensorProto.INT64,
dims=(4,),
vals=(32, 128, 256, 256),
),
)

# Create the ONNX graph with the nodes and initializers
graph = helper.make_graph(nodes, "conv_act_pool", inputs, outputs, initializer=initializers)

# Create the ONNX model
model = helper.make_model(graph)
model.opset_import[0].version = 13
model.ir_version = 10

# Check the ONNX model
model_inferred = onnx.shape_inference.infer_shapes(model)
onnx.checker.check_model(model_inferred)

return model_inferred
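For reference, the builder above wires up the following topology; the optional Reshape is a copy op and should therefore be transparent to the Conv-Act-Pool fusion matching exercised by the new test below:

# input_0 -> Conv -> BatchNormalization -> Relu [-> Reshape] -> MaxPool -> Conv -> output_0
model = build_conv_act_pool_model(include_reshape_node=True)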
64 changes: 61 additions & 3 deletions tests/unit/onnx/test_qdq_rules_int8.py
@@ -18,8 +18,11 @@
import numpy as np
import onnx
import onnx_graphsurgeon as gs
import pytest
from _test_utils.onnx.quantization.lib_test_models import (
build_conv_act_pool_model,
build_conv_batchnorm_sig_mul_model,
build_convtranspose_conv_residual_model,
build_r1a_model,
build_resnet_block,
build_resnet_block_with_downsample,
@@ -40,7 +43,7 @@ def assert_nodes_are_quantized(nodes):
return True


def _assert_nodes_are_not_quantized(nodes):
def assert_nodes_are_not_quantized(nodes):
for node in nodes:
for inp_idx, inp in enumerate(node.inputs):
if isinstance(inp, gs.Variable) and inp.inputs:
@@ -76,7 +79,7 @@ def test_bias_add_rule(tmp_path):
other_nodes = [
n for n in graph.nodes if n.op not in ["Conv", "QuantizeLinear", "DequantizeLinear"]
]
assert _assert_nodes_are_not_quantized(other_nodes)
assert assert_nodes_are_not_quantized(other_nodes)


def _check_resnet_residual_connection(onnx_path):
Expand Down Expand Up @@ -106,7 +109,7 @@ def _check_resnet_residual_connection(onnx_path):
other_nodes = [
n for n in graph.nodes if n.op not in ["Conv", "Add", "QuantizeLinear", "DequantizeLinear"]
]
assert _assert_nodes_are_not_quantized(other_nodes)
assert assert_nodes_are_not_quantized(other_nodes)


def test_resnet_residual_connections(tmp_path):
Expand All @@ -123,6 +126,35 @@ def test_resnet_residual_connection_with_downsample(tmp_path):
_check_resnet_residual_connection(onnx_path)


def test_convtranspose_conv_residual_int8(tmp_path):
onnx_model = build_convtranspose_conv_residual_model()
onnx_path = os.path.join(tmp_path, "convtranspose_conv_residual_model.onnx")
save_onnx(onnx_model, onnx_path)

quantize(onnx_path, quantize_mode="int8", high_precision_dtype="fp16")

# Output model should be produced in the same tmp_path
output_onnx_path = onnx_path.replace(".onnx", ".quant.onnx")

# Check that quantized explicit model is generated
assert os.path.isfile(output_onnx_path)

# Load the output model and check QDQ node placements
graph = gs.import_onnx(onnx.load(output_onnx_path))

# Check that Conv and ConvTranspose nodes are quantized
conv_nodes = [n for n in graph.nodes if "Conv" in n.op]
assert assert_nodes_are_quantized(conv_nodes)

# Check that only 1 input of Add is quantized
add_nodes = [n for n in graph.nodes if n.op == "Add"]
for node in add_nodes:
quantized_inputs = [inp for inp in node.inputs if inp.inputs[0].op == "DequantizeLinear"]
assert len(quantized_inputs) == 1, (
f"More than one input of {node.name} is being quantized, but only one should be quantized!"
)


def test_conv_batchnorm_sig_mul_int8(tmp_path):
onnx_model = build_conv_batchnorm_sig_mul_model()
onnx_path = os.path.join(tmp_path, "conv_batchnorm_sig_mul_model.onnx")
Expand Down Expand Up @@ -150,3 +182,29 @@ def test_conv_batchnorm_sig_mul_int8(tmp_path):
assert len(quantized_inputs) == 1, (
f"More than one input of {node.name} is being quantized, but only one should be quantized!"
)


@pytest.mark.parametrize("include_reshape_node", [False, True])
def test_conv_act_pool_int8(tmp_path, include_reshape_node):
onnx_model = build_conv_act_pool_model(include_reshape_node)
onnx_path = os.path.join(tmp_path, f"conv_act_pool_model_{include_reshape_node}.onnx")
save_onnx(onnx_model, onnx_path)

quantize(onnx_path, quantize_mode="int8", high_precision_dtype="fp16")

# Output model should be produced in the same tmp_path
output_onnx_path = onnx_path.replace(".onnx", ".quant.onnx")

# Check that quantized explicit model is generated
assert os.path.isfile(output_onnx_path)

# Load the output model and check QDQ node placements
graph = gs.import_onnx(onnx.load(output_onnx_path))

# Check that Conv is quantized
conv_nodes = [n for n in graph.nodes if n.op == "Conv"]
assert assert_nodes_are_quantized(conv_nodes)

# Check that MaxPool is not quantized
pool_nodes = [n for n in graph.nodes if n.op == "MaxPool"]
assert assert_nodes_are_not_quantized(pool_nodes)