Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 128 additions & 0 deletions onnxscript/rewriter/onnx_fusions/_layer_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import onnx_ir as ir

from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern

"""
Layer Normalization fusion optimization.

This module contains rewrite rules for fusing Layer Normalization patterns into the
ONNX LayerNormalization operator.

Layer Normalization performs normalization over the last D dimensions as specified by the axis.
The computation follows: Y = scale * (X - mean) / sqrt(variance + epsilon) + bias

Key points for the fusion optimization:
* Following restrictions from opset 17 LayerNormalization:
* Input, scale, and bias must be of same type T in {float16, bfloat16, float, double}
* The normalization can be done in a different precision than the input type (bfloat16 or float),
which is also the precision of the output mean/invstddev
"""

# input types permitted by LayerNormalization op (ONNX Opset 17)
LAYER_NORM_INPUT_TYPES = frozenset(
[
ir.DataType.FLOAT,
ir.DataType.FLOAT16,
ir.DataType.BFLOAT16,
ir.DataType.DOUBLE,
]
)

# Compute types permitted by LayerNormalization op (ONNX Opset 17), aka stash_type.
LAYER_NORM_COMPUTE_TYPES = frozenset([ir.DataType.FLOAT, ir.DataType.DOUBLE])


class LayerNormFusion(pattern.RewriteRuleClassBase):
"""Fuse LayerNorm pattern into LayerNormalization op."""

def pattern(self, op, x, scale, epsilon):
# Compute mean: Mean = ReduceMean(X, axes=normalized_axes)
# TODO: support axes attribute too
mean = op.ReduceMean(x, [-1], keepdims=1)

# Compute deviation: D = Sub(X, Mean)
deviation = op.Sub(x, mean)

# Compute squared deviation: DD = Mul(D, D)
deviation_squared = pattern.OrValue(
[
op.Mul(deviation, deviation),
op.Pow(deviation, 2),
]
)

# Compute variance: Var = ReduceMean(DD, axes=normalized_axes)
variance = op.ReduceMean(deviation_squared, [-1], keepdims=1)

# Add epsilon: VarEps = Add(Var, epsilon)
variance_plus_epsilon = op.Add(variance, epsilon)

# Compute standard deviation: StdDev = Sqrt(VarEps)
std_dev = op.Sqrt(variance_plus_epsilon)

# Compute reciprocal: InvStdDev = Reciprocal(StdDev)
# Normalize: Normalized = Mul(D, InvStdDev)

inv_std_dev = op.Reciprocal(std_dev)
normalized = pattern.OrValue(
[op.Mul(deviation, inv_std_dev), op.Div(deviation, std_dev)]
)

# Scale: NormalizedScaled = Mul(Normalized, Scale)
normalized_scaled = op.Mul(normalized, scale)

return normalized_scaled

def check(self, context, x, epsilon, **_) -> pattern.MatchResult: # type: ignore[name-defined]
"""Check if the pattern matches conditions for use of LayerNormalization op."""
check_result = pattern.MatchResult()

# Type validation:
if x.dtype not in LAYER_NORM_COMPUTE_TYPES:
return check_result.fail("Input is not a float type.", x)
self._stash_type = x.dtype

# Check that epsilon is a scalar constant
epsilon_value = _ir_utils.get_singleton_value(epsilon)
if epsilon_value is None:
return check_result.fail("Epsilon is not a constant scalar.", epsilon)
# Epsilon is guaranteed to be same type as x (float or double, in this pattern)
self._epsilon = float(epsilon_value)

return check_result

def rewrite(self, op, x, scale, epsilon, **_):
return op.LayerNormalization(
x,
scale,
axis=-1,
epsilon=self._epsilon,
stash_type=self._stash_type,
)


class LayerNormBiasFusion(pattern.RewriteRuleClassBase):
"""Fuse LayerNorm => Add into LayerNorm with bias."""

def pattern(self, op, x, scale, bias):
return op.LayerNormalization(x, scale, _outputs=["normalized"]) + bias

def rewrite(self, op, x, scale, bias, normalized):
layernorm_node = normalized.producer()
attributes = layernorm_node.attributes
num_outputs = len(layernorm_node.outputs)
return op.LayerNormalization(x, scale, bias, _outputs=num_outputs, **attributes)


# Create rules for both with and without bias
_layer_norm_rule = LayerNormFusion.rule()
_layer_norm_with_bias_rule = LayerNormBiasFusion.rule()

layer_normalization_rules = [_layer_norm_rule, _layer_norm_with_bias_rule]
layer_normalization_ruleset = pattern.RewriteRuleSet(layer_normalization_rules)

fuse_layer_normalization = _fusion_utils.apply_fusion_rules(layer_normalization_ruleset)
120 changes: 120 additions & 0 deletions onnxscript/rewriter/onnx_fusions/_layer_norm_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import unittest

import onnx_ir as ir

import onnxscript
import onnxscript.optimizer
import onnxscript.rewriter.testing
from onnxscript import FLOAT, OnnxFunction, script
from onnxscript import opset18 as op
from onnxscript.rewriter.onnx_fusions._layer_norm import fuse_layer_normalization


@script()
def _test_layer_norm_without_bias(x: FLOAT[2, 4, 8], scale: FLOAT[8]) -> FLOAT[2, 4, 8]:
"""LayerNorm pattern without bias."""
# Compute mean: Mean = ReduceMean(X, axes=normalized_axes)
mean = op.ReduceMean(x, [-1], keepdims=1)

# Compute deviation: D = Sub(X, Mean)
deviation = op.Sub(x, mean)

# Compute squared deviation: DD = Mul(D, D)
deviation_squared = op.Mul(deviation, deviation)

# Compute variance: Var = ReduceMean(DD, axes=normalized_axes)
variance = op.ReduceMean(deviation_squared, [-1], keepdims=1)

# Add epsilon: VarEps = Add(Var, epsilon)
epsilon = op.Constant(value_float=1e-5)
variance_plus_epsilon = op.Add(variance, epsilon)

# Compute standard deviation: StdDev = Sqrt(VarEps)
std_dev = op.Sqrt(variance_plus_epsilon)

# Compute reciprocal: InvStdDev = Reciprocal(StdDev)
inv_std_dev = op.Reciprocal(std_dev)

# Normalize: Normalized = Mul(D, InvStdDev)
normalized = op.Mul(deviation, inv_std_dev)

# Scale: NormalizedScaled = Mul(Normalized, Scale)
normalized_scaled = op.Mul(normalized, scale)

return normalized_scaled


@script()
def _test_layer_norm_with_bias(
x: FLOAT[2, 4, 8], scale: FLOAT[8], bias: FLOAT[8]
) -> FLOAT[2, 4, 8]:
"""LayerNorm pattern with bias."""
# Compute mean: Mean = ReduceMean(X, axes=normalized_axes)
mean = op.ReduceMean(x, [-1], keepdims=1)

# Compute deviation: D = Sub(X, Mean)
deviation = op.Sub(x, mean)

# Compute squared deviation: DD = Mul(D, D)
deviation_squared = op.Mul(deviation, deviation)

# Compute variance: Var = ReduceMean(DD, axes=normalized_axes)
variance = op.ReduceMean(deviation_squared, [-1], keepdims=1)

# Add epsilon: VarEps = Add(Var, epsilon)
epsilon = op.Constant(value_float=1e-5)
variance_plus_epsilon = op.Add(variance, epsilon)

# Compute standard deviation: StdDev = Sqrt(VarEps)
std_dev = op.Sqrt(variance_plus_epsilon)

# Compute reciprocal: InvStdDev = Reciprocal(StdDev)
inv_std_dev = op.Reciprocal(std_dev)

# Normalize: Normalized = Mul(D, InvStdDev)
normalized = op.Mul(deviation, inv_std_dev)

# Scale: NormalizedScaled = Mul(Normalized, Scale)
normalized_scaled = op.Mul(normalized, scale)

# Add bias: Y = Add(NormalizedScaled, B)
result = op.Add(normalized_scaled, bias)

return result


class LayerNormFusionTest(unittest.TestCase):
def _check(self, test_script: OnnxFunction):
"""Helper method to run a fusion test scenario."""
model_proto = test_script.to_model_proto()
# Create test inputs
input_data = onnxscript.rewriter.testing.generate_random_inputs(model_proto)

model = ir.serde.deserialize_model(model_proto)
fuse_layer_normalization(model)

onnxscript.optimizer.remove_unused_nodes(model)

# Check that a LayerNormalization node was created
self.assertEqual(["LayerNormalization"], [n.op_type for n in model.graph])

fused_model_proto = ir.serde.serialize_model(model)

onnxscript.rewriter.testing.assert_numerically_equal(
model_proto, fused_model_proto, input_data
)

def test_layer_norm_fusion_without_bias(self):
"""Test LayerNorm fusion without bias."""
self._check(_test_layer_norm_without_bias)

def test_layer_norm_fusion_with_bias(self):
"""Test LayerNorm fusion with bias."""
self._check(_test_layer_norm_with_bias)


if __name__ == "__main__":
unittest.main()
37 changes: 30 additions & 7 deletions onnxscript/rewriter/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,28 @@
from onnxscript import ir


def generate_random_inputs(model: onnx.ModelProto) -> dict[str, Any]:
feeds: dict[str, Any] = {}
for input in model.graph.input:
input_type = input.type.tensor_type
shape = tuple(input_type.shape.dim)
if not all(hasattr(d, "dim_value") for d in shape):
raise ValueError(f"Input {input.name} has dynamic shape dimensions.")
shape = tuple(d.dim_value for d in shape)
if input_type.elem_type == onnx.TensorProto.FLOAT:
if shape:
feeds[input.name] = np.random.randn(*shape).astype(np.float32)
else:
feeds[input.name] = np.random.randn(1).astype(np.float32)
else:
raise ValueError(f"Not implemented for input type {input_type.elem_type}")
return feeds


def assert_numerically_equal(
original_model_proto: onnx.ModelProto | ir.Model,
rewritten_model_proto: onnx.ModelProto | ir.Model,
args: tuple[Any, ...],
args: tuple[Any, ...] | dict[str, Any],
ort_optimization_level: ort.GraphOptimizationLevel = ort.GraphOptimizationLevel.ORT_ENABLE_ALL,
rtol: float = 1,
atol: float = 1e-3,
Expand All @@ -35,9 +53,17 @@ def assert_numerically_equal(
if isinstance(rewritten_model_proto, ir.Model):
rewritten_model_proto = ir.serde.serialize_model(rewritten_model_proto)

original_proto_ort_inputs = {
k.name: v for k, v in zip(original_model_proto.graph.input, args)
}
if isinstance(args, dict):
original_proto_ort_inputs = args
the_rewritten_proto_ort_inputs = args
else:
original_proto_ort_inputs = {
k.name: v for k, v in zip(original_model_proto.graph.input, args)
}
the_rewritten_proto_ort_inputs = {
k.name: v for k, v in zip(rewritten_model_proto.graph.input, args)
}

original_proto_ort_inference_session = _ort_session_initializer(
original_model_proto.SerializeToString(), ort_optimization_level
)
Expand All @@ -47,9 +73,6 @@ def assert_numerically_equal(
None, original_proto_ort_inputs, run_options=run_options
)

the_rewritten_proto_ort_inputs = {
k.name: v for k, v in zip(rewritten_model_proto.graph.input, args)
}
the_rewritten_proto_ort_inference_session = _ort_session_initializer(
rewritten_model_proto.SerializeToString(), ort_optimization_level
)
Expand Down
Loading