Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions onnxscript/rewriter/rules/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,21 @@
# Licensed under the MIT License.
__all__ = [
"add_0_rule",
"affine_conv_fusion_rule",
"cast_cast_rule",
"cast_constant_of_shape_rule",
"cast_constant_of_shape_without_value_rule",
"collapse_slice_rule",
"collapse_slice2_rule",
"conv_affine_fusion_rule",
"div_by_1_rule",
"dropout_inference_rule",
"dropout_zero_rule",
"flatten_to_reshape_rule",
"fuse_batchnorm_into_conv_rule",
"fuse_batchnorm_into_conv_transpose_rule",
"fuse_batchnorm_into_gemm_rule",
"fuse_hardswish_rules",
"fuse_pad_into_conv_integer_rule",
"fuse_pad_into_conv_rule",
"min_min_rule",
Expand Down Expand Up @@ -76,6 +79,11 @@
fuse_batchnorm_into_conv_transpose_rule,
fuse_batchnorm_into_gemm_rule,
)
from onnxscript.rewriter.rules.common._fuse_conv_affine import (
affine_conv_fusion_rule,
conv_affine_fusion_rule,
)
from onnxscript.rewriter.rules.common._fuse_hardswish import fuse_hardswish_rules
Copy link
Collaborator

@justinchuby justinchuby Sep 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I missed this: Could you expose the individual rules? I will merge this first - feel free to submit a follow up PR @Stonesjtu

from onnxscript.rewriter.rules.common._fuse_pad_into_conv import (
fuse_pad_into_conv_integer_rule,
fuse_pad_into_conv_rule,
Expand Down
112 changes: 112 additions & 0 deletions onnxscript/rewriter/rules/common/_fuse_conv_affine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Absorbs affine operation into convolution (best effort):
- Conv(Mul(Add(x))) -> Conv (only conv without padding can be fused)
- Add(Mul(Conv)) -> Conv (for all convolutions)
"""

from __future__ import annotations

import numpy as np
import onnx_ir as ir

from onnxscript.rewriter import pattern
from onnxscript.rewriter._basics import MatchResult
from onnxscript.rewriter._ir_utils import get_const_value, get_singleton_value


class _ConvAffineFusionBase(pattern.RewriteRuleClassBase):
    """Shared preconditions for the affine/Conv fusion rules.

    Fusion folds the affine transform into the Conv weight and bias at
    rewrite time, which is only possible when the Conv weight and bias are
    constants and the Mul/Add operands are constant scalars.
    """

    def check(
        self,
        context,
        x: ir.Value,
        w: ir.Value,
        b: ir.Value,
        scale: ir.Value,
        offset: ir.Value,
        conv_out: ir.Value,
    ) -> MatchResult:
        result = MatchResult()
        # Each requirement: (value to inspect, extractor, failure message).
        # get_const_value accepts any constant tensor; get_singleton_value
        # additionally requires it to be a single scalar element.
        requirements = (
            (w, get_const_value, "The weight of Conv should be constant"),
            (b, get_const_value, "The bias of Conv should be constant"),
            (scale, get_singleton_value, "Operand for Mul should be constant scalar value"),
            (offset, get_singleton_value, "Operand for Add should be constant scalar value"),
        )
        for value, extract, message in requirements:
            if extract(value) is None:
                return result.fail(message)
        return result


class AffineConvFusion(_ConvAffineFusionBase):
    """Fuse a scalar affine transform feeding a Conv into the Conv itself.

    Pattern: Conv(x * scale + offset, w, b) --> Conv(x, w * scale, b')
    where b'[k] = b[k] + offset * sum(w[k]).

    Only matches Convs with zero padding (``pads=[0, 0, 0, 0]``): with
    padding, padded positions contribute ``offset`` through the affine
    input in the original graph but 0 after fusion, so the rewrite would
    change results.
    """

    def pattern(
        self, op, x: ir.Value, w: ir.Value, b: ir.Value, scale: ir.Value, offset: ir.Value
    ) -> ir.Value:
        return op.Conv(
            x * scale + offset,
            w,
            b,
            pads=[0, 0, 0, 0],  # zero padding only — see class docstring
            _allow_other_attributes=True,
            _outputs=["conv_out"],
        )

    def rewrite(
        self,
        op: ir.tape.Tape,
        x: ir.Value,
        w: ir.Value,
        b: ir.Value,
        scale: ir.Value,
        offset: ir.Value,
        conv_out: ir.Value,
    ) -> ir.Value:
        # All four are guaranteed constant by _ConvAffineFusionBase.check.
        scale_value = scale.const_value.numpy()
        offset_value = offset.const_value.numpy()
        w_value = w.const_value.numpy()
        b_value = b.const_value.numpy()
        # Conv is linear in its input, so the scalar scale folds into the weights.
        scaled_w_value = op.initializer(ir.tensor(w_value * scale_value), w.name + "_scaled")
        # The constant offset contributes, per output channel k, the sum of
        # w[k] * offset over all remaining (input-channel and spatial) axes.
        reduce_axes = tuple(range(1, w_value.ndim))
        offset_bias = ir.tensor(
            b_value + np.sum(w_value * offset_value, axis=reduce_axes, keepdims=False)
        )
        offset_bias = op.initializer(offset_bias, b.name + "_offset")
        # Preserve the original Conv attributes (strides, dilations, group, ...).
        conv_attributes = conv_out.producer().attributes
        return op.Conv(x, scaled_w_value, offset_bias, **conv_attributes)


class ConvAffineFusion(_ConvAffineFusionBase):
    """Fuse a scalar affine transform following a Conv into the Conv itself.

    Pattern: Conv(x, w, b) * scale + offset --> Conv(x, w * scale, b * scale + offset)

    Unlike the pre-Conv case, this applies to any Conv (padded or not):
    the affine transform acts on the Conv *output*, so it commutes with
    padding and every other Conv attribute.
    """

    def pattern(
        self, op, x: ir.Value, w: ir.Value, b: ir.Value, scale: ir.Value, offset: ir.Value
    ) -> ir.Value:
        return (
            op.Conv(x, w, b, _allow_other_attributes=True, _outputs=["conv_out"]) * scale
            + offset
        )

    def rewrite(
        self,
        op: ir.tape.Tape,
        x: ir.Value,
        w: ir.Value,
        b: ir.Value,
        scale: ir.Value,
        offset: ir.Value,
        conv_out: ir.Value,
    ) -> ir.Value:
        # All four are guaranteed constant by _ConvAffineFusionBase.check.
        scale_value = scale.const_value.numpy()
        offset_value = offset.const_value.numpy()
        w_value = w.const_value.numpy()
        b_value = b.const_value.numpy()
        # (conv(x, w, b)) * s + o == conv(x, w * s, b * s + o) by linearity.
        scaled_w_value = op.initializer(ir.tensor(w_value * scale_value), w.name + "_scaled")
        offset_bias = ir.tensor(b_value * scale_value + offset_value)
        offset_bias = op.initializer(offset_bias, b.name + "_offset")
        # Preserve the original Conv attributes (pads, strides, dilations, group, ...).
        conv_attributes = conv_out.producer().attributes
        return op.Conv(x, scaled_w_value, offset_bias, **conv_attributes)


# Public rule instances, re-exported from onnxscript.rewriter.rules.common.
affine_conv_fusion_rule = AffineConvFusion().rule()
conv_affine_fusion_rule = ConvAffineFusion().rule()
115 changes: 115 additions & 0 deletions onnxscript/rewriter/rules/common/_fuse_conv_affine_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import unittest

import numpy as np

from onnxscript import ir
from onnxscript.rewriter import rewrite, testing
from onnxscript.rewriter.rules.common import (
affine_conv_fusion_rule,
conv_affine_fusion_rule,
)


class FuseConvAffineTest(unittest.TestCase):
    """Tests for the Conv/affine fusion rewrite rules in _fuse_conv_affine."""

    def clone_model(self, model: ir.Model) -> ir.Model:
        """Return a deep copy of *model* via a proto round-trip."""
        return ir.from_proto(ir.to_proto(model))

    def _make_graph_io_and_constants(self, tape: ir.tape.Tape):
        """Create the graph input and the constant initializers shared by both tests."""
        x = ir.Input(
            "x", shape=ir.Shape([1, 3, 32, 32]), type=ir.TensorType(ir.DataType.FLOAT)
        )
        w = tape.initializer(ir.tensor(np.ones((3, 3, 3, 3), dtype=np.float32), name="w"))
        b = tape.initializer(ir.tensor(np.ones((3,), dtype=np.float32), name="b"))
        scale = tape.initializer(ir.tensor(np.array([2.0], dtype=np.float32), name="scale"))
        offset = tape.initializer(ir.tensor(np.array([3.0], dtype=np.float32), name="offset"))
        return x, w, b, scale, offset

    def _finalize_model(self, tape: ir.tape.Tape, x, output) -> ir.Model:
        """Wrap the taped nodes into a single-input, single-output model."""
        return ir.Model(
            ir.Graph(
                inputs=[x],
                outputs=[output],
                nodes=tape.nodes,
                initializers=tape.initializers,
                opset_imports={"": 17},
            ),
            ir_version=8,
        )

    def _assert_fused_and_equivalent(self, model: ir.Model, rule) -> None:
        """Apply *rule*, check Mul+Add were absorbed, and compare outputs numerically."""
        rewritten_model = self.clone_model(model)
        rewritten_model = rewrite(
            rewritten_model,
            pattern_rewrite_rules=[rule],
        )
        # Check that Mul and Add are fused into Conv (two nodes removed).
        self.assertEqual(model.graph.num_nodes() - 2, rewritten_model.graph.num_nodes())

        # Check that the results are numerically equal.
        rng = np.random.default_rng(42)
        inputs = [
            rng.random((1, 3, 32, 32), dtype=np.float32),
        ]
        testing.assert_numerically_equal(model, rewritten_model, inputs)

    def test_conv_affine_fusion(self):
        """A (padded) Conv followed by scalar Mul/Add is fused into one Conv."""
        tape = ir.tape.Tape()
        x, w, b, scale, offset = self._make_graph_io_and_constants(tape)

        conv_out = tape.op("Conv", [x, w, b], attributes={"pads": [1, 1, 1, 1]})
        mul_out = tape.op("Mul", [conv_out, scale])
        z = tape.op(
            "Add",
            [mul_out, offset],
            output=ir.Input(
                "z",
                shape=ir.Shape([1, 3, 32, 32]),
                type=ir.TensorType(ir.DataType.FLOAT),
            ),
        )

        model = self._finalize_model(tape, x, z)
        self._assert_fused_and_equivalent(model, conv_affine_fusion_rule)

    def test_affine_conv_fusion_without_pad(self):
        """Scalar Mul/Add feeding an unpadded Conv is fused into one Conv."""
        tape = ir.tape.Tape()
        x, w, b, scale, offset = self._make_graph_io_and_constants(tape)

        mul_out = tape.op("Mul", [x, scale])
        z = tape.op(
            "Add",
            [mul_out, offset],
            output=ir.Input(
                "z",
                shape=ir.Shape([1, 3, 32, 32]),
                type=ir.TensorType(ir.DataType.FLOAT),
            ),
        )
        conv_out = tape.op("Conv", [z, w, b], attributes={"pads": [0, 0, 0, 0]})

        model = self._finalize_model(tape, x, conv_out)
        self._assert_fused_and_equivalent(model, affine_conv_fusion_rule)


# Allow running this test module directly, e.g. `python _fuse_conv_affine_test.py`.
if __name__ == "__main__":
    unittest.main()
Loading
Loading