38 changes: 19 additions & 19 deletions onnxscript/rewriter/__init__.py
@@ -22,35 +22,35 @@
import onnx_ir.passes.common as common_passes

from onnxscript import ir
from onnxscript.rewriter import (
    basic_rules,
    broadcast_to_matmul,
    cast_constant_of_shape,
    collapse_slices,
    fuse_pad_into_conv,
    fuse_relus_clips,
    no_op,
    pattern,
    redundant_scatter_nd,
)
from onnxscript.rewriter import pattern
from onnxscript.rewriter._basics import MatchContext, MatchingTracer, MatchResult, MatchStatus
from onnxscript.rewriter._rewrite_rule import (
    RewriterContext,
    RewriteRule,
    RewriteRuleClassBase,
    RewriteRuleSet,
)
from onnxscript.rewriter.rules.common import (
    _basic_rules,
    _broadcast_to_matmul,
    _cast_constant_of_shape,
    _collapse_slices,
    _fuse_pad_into_conv,
    _fuse_relus_clips,
    _no_op,
    _redundant_scatter_nd,
)

_ModelProtoOrIr = TypeVar("_ModelProtoOrIr", onnx.ModelProto, ir.Model)
_DEFAULT_REWRITE_RULES: tuple[pattern.RewriteRule, ...] = (
    *no_op.rules.rules,  # TODO: merge this rule into constant folding?
    *broadcast_to_matmul.rules.rules,
    *cast_constant_of_shape.rules.rules,
    *collapse_slices.rules.rules,
    *fuse_relus_clips.fuse_relus_clips_rules().rules,
    *basic_rules.basic_optimization_rules().rules,
    *redundant_scatter_nd.rules.rules,
    *fuse_pad_into_conv.fuse_pad_into_conv_rule_set().rules,
    *_no_op.rules,  # TODO: merge this rule into constant folding?
    *_broadcast_to_matmul.rules,
    *_cast_constant_of_shape.rules,
    *_collapse_slices.rules,
    *_fuse_relus_clips.rules,
    *_basic_rules.basic_optimization_rules(),
    *_redundant_scatter_nd.rules,
    *_fuse_pad_into_conv.rules,
)


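The rule modules previously imported from `onnxscript.rewriter.*` now live under `onnxscript.rewriter.rules.common.*`, and `_DEFAULT_REWRITE_RULES` is assembled from plain rule sequences (`.rules`) rather than `RewriteRuleSet.rules`. Below is a minimal usage sketch of the unchanged public entry point, not part of this PR: it assumes `rewrite()` falls back to the default rule set when no rules are passed and accepts either an `onnx.ModelProto` or an `ir.Model`, as the `_ModelProtoOrIr` type variable suggests; the model path is a placeholder.

```python
# Sketch only: run the default rewrite rules through the public entry point.
import onnx

from onnxscript.rewriter import rewrite

model_proto = onnx.load("model.onnx")  # placeholder path, not from this PR

# With no explicit rules, rewrite() is expected to fall back to
# _DEFAULT_REWRITE_RULES, now built from onnxscript.rewriter.rules.common.
rewritten_proto = rewrite(model_proto)
onnx.save(rewritten_proto, "model.rewritten.onnx")
```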
2 changes: 1 addition & 1 deletion onnxscript/rewriter/onnx_fusions/_onnx_fusions.py
@@ -4,7 +4,7 @@

import onnx_ir as ir

from onnxscript.rewriter.onnx_fusions import _rms_normalization, _rotary_embedding
from onnxscript.rewriter.rules.fusion import _rms_normalization, _rotary_embedding


def _get_onnx_opset_version(model: ir.Model) -> int | None:
2 changes: 1 addition & 1 deletion onnxscript/rewriter/onnx_fusions/_onnx_fusions_test.py
@@ -8,7 +8,7 @@
from parameterized import parameterized

import onnxscript
import onnxscript.rewriter.onnx_fusions as onnx_fusions
from onnxscript.rewriter import onnx_fusions
from onnxscript.rewriter.models import _rotary_embedding_models


5 changes: 3 additions & 2 deletions onnxscript/rewriter/ort_fusions/_core.py
@@ -8,7 +8,7 @@
import onnxscript.rewriter.ort_fusions.fused_matmul_rule_sets as fused_matmul_rule_sets
import onnxscript.rewriter.ort_fusions.shape_optimization as shape_optimization
from onnxscript.optimizer import optimize
from onnxscript.rewriter import gemm_to_matmul_add, rewrite
from onnxscript.rewriter import rewrite
from onnxscript.rewriter.ort_fusions import (
    instance_to_group_normalization,
    softmax,
@@ -33,6 +33,7 @@
    fuse_skip_layer_normalization,
    fuse_skip_rms_normalization,
)
from onnxscript.rewriter.rules.common import _gemm_to_matmul_add

ORT_PATTERN_REWRITE_RULES = [
    *softmax.rules.rules,
@@ -133,7 +134,7 @@ def optimize_for_ort(
        - The optimized `ir.Model` after applying transformer-specific fusions.
        - A dictionary with a count of each of the fusions applied.
    """
    rewrite(model, [gemm_to_matmul_add.rule])
    rewrite(model, [_gemm_to_matmul_add.gemm_to_matmul_add_rule])
    model, fusion_count = fuse_xformers(
        model,
        debug=debug,
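For context: `optimize_for_ort` keeps the same `rewrite(model, [rule])` call; only the import of the Gemm-decomposition rule moves. Below is a hedged sketch of applying that single rule on its own, using the `gemm_to_matmul_add_rule` re-export added in `rules/common/__init__.py`; the input model is a placeholder.

```python
# Sketch: apply only the Gemm -> MatMul + Add decomposition rule.
import onnx

from onnxscript import ir
from onnxscript.rewriter import rewrite
from onnxscript.rewriter.rules.common import gemm_to_matmul_add_rule

model_proto = onnx.load("model.onnx")  # placeholder model, not from this PR
model = ir.serde.deserialize_model(model_proto)

# Same call pattern as optimize_for_ort above; the return value is ignored there,
# so the ir.Model is assumed to be updated in place.
rewrite(model, [gemm_to_matmul_add_rule])
```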
5 changes: 3 additions & 2 deletions onnxscript/rewriter/pattern_test.py
@@ -12,7 +12,8 @@
import onnxscript.optimizer
from onnxscript import FLOAT, ir, script
from onnxscript import opset17 as op
from onnxscript.rewriter import cast_constant_of_shape, pattern
from onnxscript.rewriter import pattern
from onnxscript.rewriter.rules.common import _cast_constant_of_shape

logger = logging.getLogger(__name__)

@@ -306,7 +307,7 @@ def test_delayed_run_provides_correct_bindings_for_multiple_matches(self):
"""
)
model = ir.serde.deserialize_model(model_proto)
count = cast_constant_of_shape.rules.apply_to_model(model)
count = _cast_constant_of_shape.rules.apply_to_model(model)
self.assertEqual(count, 2)
self.assertEqual(len(model.graph), 2)
self.assertEqual(model.graph[0].attributes["value"].value.dtype, 10)
2 changes: 2 additions & 0 deletions onnxscript/rewriter/rules/__init__.py
@@ -0,0 +1,2 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
103 changes: 103 additions & 0 deletions onnxscript/rewriter/rules/common/__init__.py
@@ -0,0 +1,103 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
__all__ = [
"add_0_rule",
"cast_cast_rule",
"cast_constant_of_shape_rule",
"cast_constant_of_shape_without_value_rule",
"collapse_slice_rule",
"collapse_slice2_rule",
"div_by_1_rule",
"dropout_inference_rule",
"dropout_zero_rule",
"fuse_batchnorm_into_conv_rule",
"fuse_batchnorm_into_conv_transpose_rule",
"fuse_batchnorm_into_gemm_rule",
"fuse_pad_into_conv_integer_rule",
"fuse_pad_into_conv_rule",
"gemm_to_matmul_add_rule",
"matmul_add_to_gemm_rule",
"mul_by_1_rule",
"no_op_cast_rule",
"no_op_dynamic_scatter_nd_rule",
"no_op_expand_rule",
"no_op_static_scatter_nd_rule",
"no_op_transpose_rule",
"normalize_pad_format_conv_integer_rule",
"normalize_pad_format_conv_rule",
"one_reshape_matmul_reshape_rule",
"reshape_reshape_rule",
"slice_split_rule",
"squeeze_reshape_1d_rule",
"sub_0_rule",
"successive_clip_relu_rule",
"successive_clip_rule",
"successive_relu_clip_rule",
"successive_relu_rule",
"transpose_a_matmul_add_to_gemm_rule",
"transpose_ab_matmul_add_to_gemm_rule",
"transpose_b_matmul_add_to_gemm_rule",
"transpose_transpose_rule",
"two_reshapes_matmul_reshape_rule",
"unsqueeze_unsqueeze_rule",
]

from onnxscript.rewriter.rules.common._basic_rules import (
    cast_cast_rule,
    no_op_cast_rule,
    no_op_expand_rule,
    no_op_transpose_rule,
    reshape_reshape_rule,
    slice_split_rule,
    squeeze_reshape_1d_rule,
    transpose_transpose_rule,
    unsqueeze_unsqueeze_rule,
)
from onnxscript.rewriter.rules.common._broadcast_to_matmul import (
    one_reshape_matmul_reshape_rule,
    two_reshapes_matmul_reshape_rule,
)
from onnxscript.rewriter.rules.common._cast_constant_of_shape import (
    cast_constant_of_shape_rule,
    cast_constant_of_shape_without_value_rule,
)
from onnxscript.rewriter.rules.common._collapse_slices import (
    collapse_slice2_rule,
    collapse_slice_rule,
)
from onnxscript.rewriter.rules.common._fuse_batchnorm import (
    fuse_batchnorm_into_conv_rule,
    fuse_batchnorm_into_conv_transpose_rule,
    fuse_batchnorm_into_gemm_rule,
)
from onnxscript.rewriter.rules.common._fuse_pad_into_conv import (
    fuse_pad_into_conv_integer_rule,
    fuse_pad_into_conv_rule,
    normalize_pad_format_conv_integer_rule,
    normalize_pad_format_conv_rule,
)
from onnxscript.rewriter.rules.common._fuse_relus_clips import (
    successive_clip_relu_rule,
    successive_clip_rule,
    successive_relu_clip_rule,
    successive_relu_rule,
)
from onnxscript.rewriter.rules.common._gemm_to_matmul_add import gemm_to_matmul_add_rule
from onnxscript.rewriter.rules.common._matmul_add_to_gemm import (
    matmul_add_to_gemm_rule,
    transpose_a_matmul_add_to_gemm_rule,
    transpose_ab_matmul_add_to_gemm_rule,
    transpose_b_matmul_add_to_gemm_rule,
)
from onnxscript.rewriter.rules.common._no_op import (
    add_0_rule,
    div_by_1_rule,
    dropout_inference_rule,
    dropout_zero_rule,
    mul_by_1_rule,
    sub_0_rule,
)
from onnxscript.rewriter.rules.common._redundant_scatter_nd import (
    no_op_dynamic_scatter_nd_rule,
    no_op_static_scatter_nd_rule,
)
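The new `onnxscript.rewriter.rules.common` package re-exports each rule instance under a stable public name. Below is a small sketch of composing a custom subset from that surface; it relies only on `RewriteRuleSet` (importable from `onnxscript.rewriter`, per the `__init__.py` changes above) and on `apply_to_model`, which the tests in this PR show returning the number of rewrites applied. The model path is a placeholder.

```python
# Sketch: build a custom rule set from the new public re-exports.
import onnx

from onnxscript import ir
from onnxscript.rewriter import RewriteRuleSet
from onnxscript.rewriter.rules.common import cast_cast_rule, no_op_transpose_rule

model_proto = onnx.load("model.onnx")  # placeholder model, not from this PR
model = ir.serde.deserialize_model(model_proto)

rule_set = RewriteRuleSet([cast_cast_rule, no_op_transpose_rule])
count = rule_set.apply_to_model(model)
print(f"applied {count} rewrites")
```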
@@ -281,11 +281,11 @@ def check(self, context, x, axes1, axes2) -> MatchResult:

# Create rule instances
cast_cast_rule = CastCast.rule()
cast_identity_rule = CastIdentity.rule()
expand_identity_rule = ExpandIdentity.rule()
no_op_cast_rule = CastIdentity.rule()
no_op_expand_rule = ExpandIdentity.rule()
reshape_reshape_rule = ReshapeReshape.rule()
slice_split_rule = SlicesSplit.rule()
transpose_identity_rule = TransposeIdentity.rule()
no_op_transpose_rule = TransposeIdentity.rule()
transpose_transpose_rule = TransposeTranspose.rule()
unsqueeze_unsqueeze_rule = UnsqueezeUnsqueeze.rule()
squeeze_reshape_1d_rule = SqueezeReshape.rule()
@@ -309,11 +309,11 @@ def basic_optimization_rules() -> RewriteRuleSet:
    return RewriteRuleSet(
        [
            cast_cast_rule,
            cast_identity_rule,
            expand_identity_rule,
            no_op_cast_rule,
            no_op_expand_rule,
            reshape_reshape_rule,
            slice_split_rule,
            transpose_identity_rule,
            no_op_transpose_rule,
            transpose_transpose_rule,
            unsqueeze_unsqueeze_rule,
            squeeze_reshape_1d_rule,
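The rule instances formerly named `*_identity_rule` are renamed to `no_op_*_rule`, and `basic_optimization_rules()` still returns a `RewriteRuleSet` over the same instances. A hedged sketch of driving that set against a toy model containing an identity `Transpose` follows; the outcome is printed rather than asserted, since whether the no-op `Transpose` is dropped or replaced is not verified here.

```python
# Sketch: run the renamed basic rules over a toy model with a no-op Transpose.
import onnx
import onnx.helper as oh

from onnxscript import ir
from onnxscript.rewriter.rules.common import _basic_rules

graph = oh.make_graph(
    [oh.make_node("Transpose", ["x"], ["y"], perm=[0, 1, 2])],  # identity permutation
    "toy_graph",
    [oh.make_tensor_value_info("x", onnx.TensorProto.FLOAT, [2, 3, 4])],
    [oh.make_tensor_value_info("y", onnx.TensorProto.FLOAT, [2, 3, 4])],
)
model_proto = oh.make_model(graph, opset_imports=[oh.make_opsetid("", 18)])

model = ir.serde.deserialize_model(model_proto)
count = _basic_rules.basic_optimization_rules().apply_to_model(model)
print(count, [node.op_type for node in ir.serde.serialize_model(model).graph.node])
```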
@@ -12,9 +12,9 @@

import onnxscript
import onnxscript.onnx_types as ot
import onnxscript.rewriter.basic_rules as basic_rules
from onnxscript import ir
from onnxscript.onnx_opset import opset18
from onnxscript.rewriter.rules.common import _basic_rules

FLOAT = onnx.TensorProto.FLOAT

@@ -98,7 +98,7 @@ def _check_model(
        ]
    )
    def test_basic_optimization_rules_identity(self, _: str, model: ir.Model):
        rule_set = basic_rules.basic_optimization_rules()
        rule_set = _basic_rules.basic_optimization_rules()
        model_proto = ir.serde.serialize_model(model)
        rule_set.apply_to_model(model)
        rewritten_model = ir.serde.serialize_model(model)
@@ -126,7 +126,7 @@ def test_basic_optimization_rules_identity(self, _: str, model: ir.Model):
        ]
    )
    def test_basic_optimization_rules_transpose_transpose(self, _: str, model: ir.Model):
        rule_set = basic_rules.basic_optimization_rules()
        rule_set = _basic_rules.basic_optimization_rules()
        model_proto = ir.serde.serialize_model(model)
        rule_set.apply_to_model(model)
        rewritten_model = ir.serde.serialize_model(model)
@@ -153,7 +153,7 @@ def cast_cast_model(x):
        ]
    )
    def test_cast_cast_rule(self, _: str, type1, type2, type3):
        rule = basic_rules.cast_cast_rule
        rule = _basic_rules.cast_cast_rule
        model_proto = self._double_cast_model(type1, type2, type3)
        model = ir.serde.deserialize_model(model_proto)
        rule.apply_to_model(model)
@@ -172,7 +172,7 @@ def test_cast_cast_rule(self, _: str, type1, type2, type3):
        ]
    )
    def test_cast_identity_rule(self, _: str, model: ir.Model):
        rule_set = basic_rules.basic_optimization_rules()
        rule_set = _basic_rules.basic_optimization_rules()
        model_proto = ir.serde.serialize_model(model)
        rule_set.apply_to_model(model)
        rewritten_model = ir.serde.serialize_model(model)
@@ -228,7 +228,7 @@ def test_cast_identity_rule(self, _: str, model: ir.Model):
    def test_expand_identity_rule(
        self, _: str, model: ir.Model, expected_nodes: tuple[str, ...]
    ):
        rule_set = basic_rules.basic_optimization_rules()
        rule_set = _basic_rules.basic_optimization_rules()
        model_proto = ir.serde.serialize_model(model)
        rule_set.apply_to_model(model)
        rewritten_model = ir.serde.serialize_model(model)
@@ -310,7 +310,7 @@ def test_expand_identity_rule(
        ]
    )
    def test_unsqueeze_unsqueeze_rule(self, _: str, model: ir.Model):
        rule_set = basic_rules.basic_optimization_rules()
        rule_set = _basic_rules.basic_optimization_rules()
        model_proto = ir.serde.serialize_model(model)
        rule_set.apply_to_model(model)
        rewritten_model = ir.serde.serialize_model(model)
@@ -369,7 +369,7 @@ def test_unsqueeze_unsqueeze_rule(self, _: str, model: ir.Model):
        ]
    )
    def test_reshape_reshape_rule(self, _: str, model: ir.Model):
        rule_set = basic_rules.basic_optimization_rules()
        rule_set = _basic_rules.basic_optimization_rules()
        model_proto = ir.serde.serialize_model(model)
        rule_set.apply_to_model(model)
        rewritten_model = ir.serde.serialize_model(model)
@@ -420,15 +420,15 @@ def _slices_split_models(cls):
    def test_slices_split_rule(self):
        for model_proto in self._slices_split_models():
            ir_model = ir.serde.deserialize_model(model_proto)
            rule_set = basic_rules.basic_optimization_rules()
            rule_set = _basic_rules.basic_optimization_rules()
            rule_set.apply_to_model(ir_model)
            rewritten_model = ir.serde.serialize_model(ir_model)

            self.assertEqual(["Split"], [n.op_type for n in rewritten_model.graph.node])
            self._check_model(model_proto, rewritten_model)

    def test_squeeze_reshape_1d_rule(self):
        rule = basic_rules.squeeze_reshape_1d_rule
        rule = _basic_rules.squeeze_reshape_1d_rule

        def check(model_script, expected_count) -> None:
            model_proto = model_script.to_model_proto()