Commit 0d1c0df

Move into onnx fusion folder

Signed-off-by: Ganesan Ramalingam <[email protected]>

1 parent 5cbb4d4

File tree: 3 files changed, +23 −21 lines changed


onnxscript/rewriter/layer_normalization.py renamed to onnxscript/rewriter/onnx_fusions/_layer_norm.py

Lines changed: 18 additions & 10 deletions
@@ -34,6 +34,11 @@
 
 
 class LayerNormFusion(pattern.RewriteRuleClassBase):
+    def __init__(self, name: str, has_bias: bool):
+        super().__init__(name)
+        self._has_bias = has_bias
+        self._stash_dtype: int | None = None
+
     def pattern(self, op, x, scale, bias, epsilon, target_dtype):
         # Compute mean: Mean = ReduceMean(X, axes=normalized_axes)
         # TODO: support axes attribute too
@@ -43,8 +48,10 @@ def pattern(self, op, x, scale, bias, epsilon, target_dtype):
         deviation = op.Sub(x, mean)
 
         # Compute squared deviation: DD = Mul(D, D)
-        # TODO: support Pow (D, 2) as well
-        deviation_squared = op.Mul(deviation, deviation)
+        deviation_squared = pattern.OrValue([
+            op.Mul(deviation, deviation),
+            op.Pow(deviation, 2),
+        ])
 
         # Compute variance: Var = ReduceMean(DD, axes=normalized_axes)
         variance = op.ReduceMean(deviation_squared, [-1], keepdims=1)
@@ -56,24 +63,25 @@ def pattern(self, op, x, scale, bias, epsilon, target_dtype):
         std_dev = op.Sqrt(variance_plus_epsilon)
 
         # Compute reciprocal: InvStdDev = Reciprocal(StdDev)
-        # TODO: support Div(deviation, std_dev) as well?
-        inv_std_dev = op.Reciprocal(std_dev)
-
         # Normalize: Normalized = Mul(D, InvStdDev)
-        normalized = op.Mul(deviation, inv_std_dev)
+
+        inv_std_dev = op.Reciprocal(std_dev)
+        normalized = pattern.OrValue([
+            op.Mul(deviation, inv_std_dev),
+            op.Div(deviation, std_dev)
+        ])
 
         # Scale: NormalizedScaled = Mul(Normalized, Scale)
         normalized_scaled = op.Mul(normalized, scale)
 
         # Add bias (if present): Y = Add(NormalizedScaled, B)
-        if bias is not None:
+
+        if self._has_bias:
             return op.Add(normalized_scaled, bias)
         else:
             return normalized_scaled
 
-    def check(
-        self, op, x, scale, bias, epsilon, compute_dtype, target_dtype, **_
-    ) -> pattern.MatchResult:  # type: ignore[name-defined]
+    def check(self, op, x, scale, bias, epsilon, target_dtype, **_) -> pattern.MatchResult:  # type: ignore[name-defined]
         """Check if the pattern matches conditions for use of LayerNormalization op."""
         check_result = pattern.MatchResult()
 
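The substance of this change: the rule's pattern now uses pattern.OrValue to accept alternative formulations of the same computation, Pow(D, 2) as well as Mul(D, D) for the squared deviation, and Div(D, StdDev) as well as Mul(D, Reciprocal(StdDev)) for the normalization, resolving the two deleted TODO comments. The bias branch is also selected by the has_bias constructor flag instead of bias being None at match time, so the with-bias and without-bias cases become two separately instantiated rules. As a hedged illustration (not part of this commit), here is a layer-norm subgraph written with the Pow/Div variants that the broadened rule should now be able to match; the function name, shapes, and epsilon value are made up for the sketch:

from onnxscript import FLOAT, script
from onnxscript import opset18 as op


@script()
def _layer_norm_pow_div_variant(x: FLOAT[2, 4, 8], scale: FLOAT[8]) -> FLOAT[2, 4, 8]:
    mean = op.ReduceMean(x, [-1], keepdims=1)
    deviation = op.Sub(x, mean)
    # Pow(D, 2) instead of Mul(D, D): covered by the first OrValue alternative.
    deviation_squared = op.Pow(deviation, 2.0)
    variance = op.ReduceMean(deviation_squared, [-1], keepdims=1)
    std_dev = op.Sqrt(op.Add(variance, 1e-05))
    # Div(D, StdDev) instead of Mul(D, Reciprocal(StdDev)): the second alternative.
    normalized = op.Div(deviation, std_dev)
    return op.Mul(normalized, scale)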

onnxscript/rewriter/layer_normalization_test.py renamed to onnxscript/rewriter/onnx_fusions/_layer_norm_test.py

Lines changed: 4 additions & 10 deletions
@@ -9,7 +9,7 @@
 import onnxscript.rewriter.testing
 from onnxscript import FLOAT, OnnxFunction, script
 from onnxscript import opset18 as op
-from onnxscript.rewriter.layer_normalization import fuse_layer_normalization
+from onnxscript.rewriter.onnx_fusions._layer_norm import fuse_layer_normalization
 
 
 @script()
@@ -86,13 +86,7 @@ def _test_layer_norm_with_bias(
 
 
 class LayerNormFusionTest(unittest.TestCase):
-    def _check(
-        self,
-        test_script: OnnxFunction,
-        expected_graph_len: int,
-        expected_op_type: str,
-        has_bias: bool = False,
-    ):
+    def _check(self, test_script: OnnxFunction):
         """Helper method to run a fusion test scenario."""
         model_proto = test_script.to_model_proto()
         # Create test inputs
@@ -112,11 +106,11 @@ def _check(
 
     def test_layer_norm_fusion_without_bias(self):
         """Test LayerNorm fusion without bias."""
-        self._check(_test_layer_norm_without_bias, 1, "LayerNormalization", has_bias=False)
+        self._check(_test_layer_norm_without_bias)
 
     def test_layer_norm_fusion_with_bias(self):
         """Test LayerNorm fusion with bias."""
-        self._check(_test_layer_norm_with_bias, 1, "LayerNormalization", has_bias=True)
+        self._check(_test_layer_norm_with_bias)
 
 
 if __name__ == "__main__":
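The diff shows only the narrowed signature of _check, not its new body, so the following reconstruction is a guess at what the helper now does: fuse, then assert the graph collapsed to a single LayerNormalization node (the values the dropped expected_graph_len and expected_op_type arguments always carried). Every detail below, including fuse_layer_normalization accepting an ir.Model, is an assumption:

import unittest

from onnxscript import OnnxFunction, ir
from onnxscript.rewriter.onnx_fusions._layer_norm import fuse_layer_normalization


class _SketchOfLayerNormFusionTest(unittest.TestCase):
    def _check(self, test_script: OnnxFunction):
        """Helper method to run a fusion test scenario."""
        model_proto = test_script.to_model_proto()
        model = ir.from_proto(model_proto)
        fuse_layer_normalization(model)
        nodes = list(model.graph)
        # The dropped arguments are now implicit: every scenario expects the
        # whole function body to fuse into one LayerNormalization node.
        self.assertEqual(len(nodes), 1)
        self.assertEqual(nodes[0].op_type, "LayerNormalization")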

onnxscript/rewriter/testing.py

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@
 from onnxscript import ir
 
 
-def generate_random_inputs(self, model: onnx.ModelProto) -> dict[str, Any]:
+def generate_random_inputs(model: onnx.ModelProto) -> dict[str, Any]:
     feeds: dict[str, Any] = {}
     for input in model.graph.input:
         input_type = input.type.tensor_type
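The testing.py change is a small bug fix: generate_random_inputs is a module-level helper, so the stray self parameter (left over from a method context) would have forced callers to pass an extra argument. A hedged usage sketch, assuming onnxruntime for execution (not part of this diff) and hypothetical original_proto/fused_proto models:

import numpy as np
import onnx
import onnxruntime as ort

from onnxscript.rewriter.testing import generate_random_inputs


def _run(model: onnx.ModelProto, feeds: dict) -> list:
    # CPU provider keeps the sketch portable; any provider would do.
    session = ort.InferenceSession(
        model.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    return session.run(None, feeds)


# feeds = generate_random_inputs(original_proto)  # no stray `self` argument now
# np.testing.assert_allclose(
#     _run(original_proto, feeds)[0], _run(fused_proto, feeds)[0], rtol=1e-4
# )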
