Commit 08dba74

Add layernorm fusion
Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
1 parent 7407431 commit 08dba74

3 files changed (+268 -0 lines changed)

Lines changed: 119 additions & 0 deletions
@@ -0,0 +1,119 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import onnx_ir as ir

from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern

"""
Layer Normalization fusion optimization.

This module contains rewrite rules for fusing Layer Normalization patterns into the
ONNX LayerNormalization operator.

Layer Normalization performs normalization over the last D dimensions as specified by the axis.
The computation follows: Y = scale * (X - mean) / sqrt(variance + epsilon) + bias

Key points for the fusion optimization:
* Following restrictions from opset 17 LayerNormalization:
    * Input, scale, and bias must be of the same type T in {float16, bfloat16, float, double}
    * The normalization can be done in a different precision than the input type (bfloat16 or float),
      which is also the precision of the output mean/invstddev
"""

float_types = frozenset(
    [
        ir.DataType.FLOAT,
        ir.DataType.FLOAT16,
        ir.DataType.BFLOAT16,
        ir.DataType.DOUBLE,
    ]
)
fp_float_types = frozenset([ir.DataType.FLOAT, ir.DataType.DOUBLE])


class LayerNormFusion(pattern.RewriteRuleClassBase):
    def __init__(self, name: str, *, has_bias: bool = False):
        super().__init__(name)
        self._has_bias = has_bias

    def pattern(self, op, x, scale, bias, epsilon):
        # Compute mean: Mean = ReduceMean(X, axes=normalized_axes)
        # TODO: support axes attribute too
        mean = op.ReduceMean(x, [-1], keepdims=1)

        # Compute deviation: D = Sub(X, Mean)
        deviation = op.Sub(x, mean)

        # Compute squared deviation: DD = Mul(D, D)
        # TODO: support Pow(D, 2) as well
        deviation_squared = op.Mul(deviation, deviation)

        # Compute variance: Var = ReduceMean(DD, axes=normalized_axes)
        variance = op.ReduceMean(deviation_squared, [-1], keepdims=1)

        # Add epsilon: VarEps = Add(Var, epsilon)
        variance_plus_epsilon = op.Add(variance, epsilon)

        # Compute standard deviation: StdDev = Sqrt(VarEps)
        std_dev = op.Sqrt(variance_plus_epsilon)

        # Compute reciprocal: InvStdDev = Reciprocal(StdDev)
        # TODO: support Div(deviation, std_dev) as well?
        inv_std_dev = op.Reciprocal(std_dev)

        # Normalize: Normalized = Mul(D, InvStdDev)
        normalized = op.Mul(deviation, inv_std_dev)

        # Scale: NormalizedScaled = Mul(Normalized, Scale)
        normalized_scaled = op.Mul(normalized, scale)

        # Add bias (if present): Y = Add(NormalizedScaled, B)
        if self._has_bias:
            return op.Add(normalized_scaled, bias)
        return normalized_scaled

    def check(self, op, x, scale, epsilon, bias=None, **_) -> pattern.MatchResult:  # type: ignore[name-defined]
        """Check if the pattern matches conditions for use of LayerNormalization op."""
        check_result = pattern.MatchResult()

        # epsilon must be a scalar
        epsilon_value = _ir_utils.get_singleton_value(epsilon)
        if not isinstance(epsilon_value, float):  # TODO: support other types
            return check_result.fail("Epsilon is not a float value.", epsilon)

        if x.dtype not in fp_float_types:
            return check_result.fail("Input is not a float type.", x)

        # Stash type records the precision used for the mean/invstddev computation.
        self._stash_dtype = x.dtype

        return check_result

    def rewrite(self, op, x, scale, epsilon, bias=None, **_):
        epsilon_value = _ir_utils.get_singleton_value(epsilon)
        if bias is not None:
            return op.LayerNormalization(
                x,
                scale,
                bias,
                axis=-1,
                epsilon=epsilon_value,
                stash_type=self._stash_dtype,
            )
        return op.LayerNormalization(
            x,
            scale,
            axis=-1,
            epsilon=epsilon_value,
            stash_type=self._stash_dtype,
        )


# Create rules for both with and without bias
_layer_norm_with_bias_rule = LayerNormFusion.rule("LayerNormWithBias", has_bias=True)
_layer_norm_rule = LayerNormFusion.rule("LayerNorm", has_bias=False)

layer_normalization_rules = [_layer_norm_with_bias_rule, _layer_norm_rule]
layer_normalization_ruleset = pattern.RewriteRuleSet(layer_normalization_rules)

fuse_layer_normalization = _fusion_utils.apply_fusion_rules(layer_normalization_ruleset)
Lines changed: 131 additions & 0 deletions
@@ -0,0 +1,131 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import unittest

import onnx_ir as ir

import onnxscript.rewriter.ort_fusions._test_utils as test_utils
import onnxscript.rewriter.testing
from onnxscript import FLOAT, OnnxFunction, script
from onnxscript import opset17 as op
from onnxscript.optimizer import remove_unused_nodes
from onnxscript.rewriter.layer_normalization import fuse_layer_normalization


@script()
def _test_layer_norm_without_bias(x: FLOAT[2, 4, 8], scale: FLOAT[8]) -> FLOAT[2, 4, 8]:
    """LayerNorm pattern without bias."""
    # Compute mean: Mean = ReduceMean(X, axes=normalized_axes)
    mean = op.ReduceMean(x, [-1], keepdims=1)

    # Compute deviation: D = Sub(X, Mean)
    deviation = op.Sub(x, mean)

    # Compute squared deviation: DD = Mul(D, D)
    deviation_squared = op.Mul(deviation, deviation)

    # Compute variance: Var = ReduceMean(DD, axes=normalized_axes)
    variance = op.ReduceMean(deviation_squared, [-1], keepdims=1)

    # Add epsilon: VarEps = Add(Var, epsilon)
    epsilon = op.Constant(value_float=1e-5)
    variance_plus_epsilon = op.Add(variance, epsilon)

    # Compute standard deviation: StdDev = Sqrt(VarEps)
    std_dev = op.Sqrt(variance_plus_epsilon)

    # Compute reciprocal: InvStdDev = Reciprocal(StdDev)
    inv_std_dev = op.Reciprocal(std_dev)

    # Normalize: Normalized = Mul(D, InvStdDev)
    normalized = op.Mul(deviation, inv_std_dev)

    # Scale: NormalizedScaled = Mul(Normalized, Scale)
    normalized_scaled = op.Mul(normalized, scale)

    return normalized_scaled


@script()
def _test_layer_norm_with_bias(
    x: FLOAT[2, 4, 8], scale: FLOAT[8], bias: FLOAT[8]
) -> FLOAT[2, 4, 8]:
    """LayerNorm pattern with bias."""
    # Compute mean: Mean = ReduceMean(X, axes=normalized_axes)
    mean = op.ReduceMean(x, [-1], keepdims=1)

    # Compute deviation: D = Sub(X, Mean)
    deviation = op.Sub(x, mean)

    # Compute squared deviation: DD = Mul(D, D)
    deviation_squared = op.Mul(deviation, deviation)

    # Compute variance: Var = ReduceMean(DD, axes=normalized_axes)
    variance = op.ReduceMean(deviation_squared, [-1], keepdims=1)

    # Add epsilon: VarEps = Add(Var, epsilon)
    epsilon = op.Constant(value_float=1e-5)
    variance_plus_epsilon = op.Add(variance, epsilon)

    # Compute standard deviation: StdDev = Sqrt(VarEps)
    std_dev = op.Sqrt(variance_plus_epsilon)

    # Compute reciprocal: InvStdDev = Reciprocal(StdDev)
    inv_std_dev = op.Reciprocal(std_dev)

    # Normalize: Normalized = Mul(D, InvStdDev)
    normalized = op.Mul(deviation, inv_std_dev)

    # Scale: NormalizedScaled = Mul(Normalized, Scale)
    normalized_scaled = op.Mul(normalized, scale)

    # Add bias: Y = Add(NormalizedScaled, B)
    result = op.Add(normalized_scaled, bias)

    return result


class LayerNormFusionTest(unittest.TestCase):
    def _check(
        self,
        test_data_constructor: OnnxFunction,
        expected_graph_len: int,
        expected_op_type: str,
    ):
        """Helper method to run a fusion test scenario."""
        model_proto = test_data_constructor.to_model_proto()

        # Create test inputs
        input_data = onnxscript.rewriter.testing.generate_random_inputs(model_proto)

        model = ir.serde.deserialize_model(model_proto)

        # Run original model
        original_output = test_utils.ort_run("Original", model, input_data)

        # Apply fusion
        fuse_layer_normalization(model)
        remove_unused_nodes(model)

        # Verify fusion occurred
        self.assertEqual(len(model.graph), expected_graph_len)
        self.assertEqual(model.graph.node(0).op_type, expected_op_type)

        # Run optimized model and verify outputs match
        optimized_output = test_utils.ort_run("Optimized", model, input_data)
        test_utils.assert_allclose(original_output, optimized_output, rtol=1e-4, atol=1e-4)

    def test_layer_norm_fusion_without_bias(self):
        """Test LayerNorm fusion without bias."""
        self._check(_test_layer_norm_without_bias, 1, "LayerNormalization")

    def test_layer_norm_fusion_with_bias(self):
        """Test LayerNorm fusion with bias."""
        self._check(_test_layer_norm_with_bias, 1, "LayerNormalization")


if __name__ == "__main__":
    unittest.main()
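For context, applying the fusion to a serialized model outside the unit tests would look roughly like this (an illustrative sketch built from the APIs exercised above; "model.onnx" and "model_fused.onnx" are placeholder paths):

import onnx
import onnx_ir as ir

from onnxscript.optimizer import remove_unused_nodes
from onnxscript.rewriter.layer_normalization import fuse_layer_normalization

# Load a model that contains the decomposed LayerNorm subgraph (placeholder path).
model_proto = onnx.load("model.onnx")
model = ir.serde.deserialize_model(model_proto)

# Rewrite matching subgraphs into LayerNormalization nodes and drop dead nodes.
fuse_layer_normalization(model)
remove_unused_nodes(model)

onnx.save(ir.serde.serialize_model(model), "model_fused.onnx")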

onnxscript/rewriter/testing.py

Lines changed: 18 additions & 0 deletions
@@ -10,6 +10,24 @@
from onnxscript import ir


def generate_random_inputs(model: onnx.ModelProto) -> dict[str, Any]:
    feeds: dict[str, Any] = {}
    for input in model.graph.input:
        input_type = input.type.tensor_type
        dims = input_type.shape.dim
        if not all(d.HasField("dim_value") for d in dims):
            raise ValueError(f"Input {input.name} has dynamic shape dimensions.")
        shape = tuple(d.dim_value for d in dims)
        if input_type.elem_type == onnx.TensorProto.FLOAT:
            if shape:
                feeds[input.name] = np.random.randn(*shape).astype(np.float32)
            else:
                feeds[input.name] = np.random.randn(1).astype(np.float32)
        else:
            raise ValueError(f"Not implemented for input type {input_type.elem_type}")
    return feeds


def assert_numerically_equal(
    original_model_proto: onnx.ModelProto | ir.Model,
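The new generate_random_inputs helper can be paired with onnxruntime for a quick smoke test (an illustrative sketch, not part of the commit; assumes onnxruntime is installed and "model.onnx" is a placeholder path):

import onnx
import onnxruntime

from onnxscript.rewriter.testing import generate_random_inputs

model_proto = onnx.load("model.onnx")  # placeholder path
feeds = generate_random_inputs(model_proto)  # random float32 tensors keyed by input name

# Run the model once on the random inputs.
session = onnxruntime.InferenceSession(
    model_proto.SerializeToString(), providers=["CPUExecutionProvider"]
)
outputs = session.run(None, feeds)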
