
Commit 5cbb4d4

Minor fixes
Signed-off-by: Ganesan Ramalingam <[email protected]>
1 parent: 08dba74; commit: 5cbb4d4

3 files changed: +63 -66 lines changed


onnxscript/rewriter/layer_normalization.py

Lines changed: 15 additions & 15 deletions
@@ -9,7 +9,7 @@
 """
 Layer Normalization fusion optimization.

-This module contains rewrite rules for fusing Layer Normalization patterns into the 
+This module contains rewrite rules for fusing Layer Normalization patterns into the
 ONNX LayerNormalization operator.

 Layer Normalization performs normalization over the last D dimensions as specified by the axis.
@@ -34,37 +34,37 @@


 class LayerNormFusion(pattern.RewriteRuleClassBase):
-    def pattern(self, op, x, scale, bias, epsilon, target_dtype):
+    def pattern(self, op, x, scale, bias, epsilon, target_dtype):
         # Compute mean: Mean = ReduceMean(X, axes=normalized_axes)
         # TODO: support axes attribute too
         mean = op.ReduceMean(x, [-1], keepdims=1)
-
+
         # Compute deviation: D = Sub(X, Mean)
         deviation = op.Sub(x, mean)
-
+
         # Compute squared deviation: DD = Mul(D, D)
         # TODO: support Pow (D, 2) as well
         deviation_squared = op.Mul(deviation, deviation)
-
+
         # Compute variance: Var = ReduceMean(DD, axes=normalized_axes)
         variance = op.ReduceMean(deviation_squared, [-1], keepdims=1)
-
+
         # Add epsilon: VarEps = Add(Var, epsilon)
         variance_plus_epsilon = op.Add(variance, epsilon)
-
+
         # Compute standard deviation: StdDev = Sqrt(VarEps)
         std_dev = op.Sqrt(variance_plus_epsilon)
-
+
         # Compute reciprocal: InvStdDev = Reciprocal(StdDev)
         # TODO: support Div(deviation, std_dev) as well?
         inv_std_dev = op.Reciprocal(std_dev)
-
+
         # Normalize: Normalized = Mul(D, InvStdDev)
         normalized = op.Mul(deviation, inv_std_dev)
-
+
         # Scale: NormalizedScaled = Mul(Normalized, Scale)
         normalized_scaled = op.Mul(normalized, scale)
-
+
         # Add bias (if present): Y = Add(NormalizedScaled, B)
         if bias is not None:
             return op.Add(normalized_scaled, bias)
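For reference, the op chain matched by this pattern is ordinary layer normalization over the last axis. A minimal NumPy sketch of the same arithmetic (illustrative only, not part of this commit):

import numpy as np

def layer_norm_reference(x, scale, bias=None, epsilon=1e-5):
    # Mean and variance over the last axis with keepdims, mirroring the pattern above.
    mean = x.mean(axis=-1, keepdims=True)
    deviation = x - mean
    variance = (deviation * deviation).mean(axis=-1, keepdims=True)
    inv_std_dev = 1.0 / np.sqrt(variance + epsilon)
    scaled = deviation * inv_std_dev * scale
    return scaled + bias if bias is not None else scaled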
@@ -76,17 +76,17 @@ def check(
     ) -> pattern.MatchResult:  # type: ignore[name-defined]
         """Check if the pattern matches conditions for use of LayerNormalization op."""
         check_result = pattern.MatchResult()
-
+
         # epsilon must be a scalar
         epsilon_value = _ir_utils.get_singleton_value(epsilon)
         if not isinstance(epsilon_value, float):  # TODO: support other types
             return check_result.fail("Epsilon is not a float value.", epsilon)
-
+
         if x.dtype not in fp_float_types:
             return check_result.fail("Input is not a float type.", x)
-
+
         self._stash_dtype = x.dtype
-
+
         return check_result

     def rewrite(self, op, x, scale, bias, epsilon, **_):
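The body of rewrite is outside this diff's context lines. A plausible sketch of what the replacement typically looks like, assuming the fusion emits a single ONNX LayerNormalization node and that the dtype stashed in check() feeds its stash_type attribute:

def rewrite(self, op, x, scale, bias, epsilon, **_):
    # Sketch only: replace the matched subgraph with one fused node.
    # Assumption: the dtype stashed in check() populates stash_type.
    epsilon_value = _ir_utils.get_singleton_value(epsilon)
    if bias is None:
        return op.LayerNormalization(
            x, scale, axis=-1, epsilon=epsilon_value, stash_type=self._stash_dtype
        )
    return op.LayerNormalization(
        x, scale, bias, axis=-1, epsilon=epsilon_value, stash_type=self._stash_dtype
    )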

onnxscript/rewriter/layer_normalization_test.py

Lines changed: 34 additions & 42 deletions
@@ -3,120 +3,112 @@

 import unittest

-import numpy as np
 import onnx_ir as ir
-import parameterized

 import onnxscript
-import onnxscript.rewriter.ort_fusions._test_utils as test_utils
+import onnxscript.rewriter.testing
 from onnxscript import FLOAT, OnnxFunction, script
-from onnxscript import opset17 as op
-from onnxscript.optimizer import optimize, remove_unused_nodes
+from onnxscript import opset18 as op
 from onnxscript.rewriter.layer_normalization import fuse_layer_normalization
-import onnxscript.rewriter.testing


 @script()
 def _test_layer_norm_without_bias(x: FLOAT[2, 4, 8], scale: FLOAT[8]) -> FLOAT[2, 4, 8]:
     """LayerNorm pattern without bias."""
     # Compute mean: Mean = ReduceMean(X, axes=normalized_axes)
     mean = op.ReduceMean(x, [-1], keepdims=1)
-
+
     # Compute deviation: D = Sub(X, Mean)
     deviation = op.Sub(x, mean)
-
+
     # Compute squared deviation: DD = Mul(D, D)
     deviation_squared = op.Mul(deviation, deviation)
-
+
     # Compute variance: Var = ReduceMean(DD, axes=normalized_axes)
     variance = op.ReduceMean(deviation_squared, [-1], keepdims=1)
-
+
     # Add epsilon: VarEps = Add(Var, epsilon)
     epsilon = op.Constant(value_float=1e-5)
     variance_plus_epsilon = op.Add(variance, epsilon)
-
+
     # Compute standard deviation: StdDev = Sqrt(VarEps)
     std_dev = op.Sqrt(variance_plus_epsilon)
-
+
     # Compute reciprocal: InvStdDev = Reciprocal(StdDev)
     inv_std_dev = op.Reciprocal(std_dev)
-
+
     # Normalize: Normalized = Mul(D, InvStdDev)
     normalized = op.Mul(deviation, inv_std_dev)
-
+
     # Scale: NormalizedScaled = Mul(Normalized, Scale)
     normalized_scaled = op.Mul(normalized, scale)
-
+
     return normalized_scaled


 @script()
-def _test_layer_norm_with_bias(x: FLOAT[2, 4, 8], scale: FLOAT[8], bias: FLOAT[8]) -> FLOAT[2, 4, 8]:
+def _test_layer_norm_with_bias(
+    x: FLOAT[2, 4, 8], scale: FLOAT[8], bias: FLOAT[8]
+) -> FLOAT[2, 4, 8]:
     """LayerNorm pattern with bias."""
     # Compute mean: Mean = ReduceMean(X, axes=normalized_axes)
     mean = op.ReduceMean(x, [-1], keepdims=1)
-
+
     # Compute deviation: D = Sub(X, Mean)
     deviation = op.Sub(x, mean)
-
+
     # Compute squared deviation: DD = Mul(D, D)
     deviation_squared = op.Mul(deviation, deviation)
-
+
     # Compute variance: Var = ReduceMean(DD, axes=normalized_axes)
     variance = op.ReduceMean(deviation_squared, [-1], keepdims=1)
-
+
     # Add epsilon: VarEps = Add(Var, epsilon)
     epsilon = op.Constant(value_float=1e-5)
     variance_plus_epsilon = op.Add(variance, epsilon)
-
+
     # Compute standard deviation: StdDev = Sqrt(VarEps)
     std_dev = op.Sqrt(variance_plus_epsilon)
-
+
     # Compute reciprocal: InvStdDev = Reciprocal(StdDev)
     inv_std_dev = op.Reciprocal(std_dev)
-
+
     # Normalize: Normalized = Mul(D, InvStdDev)
     normalized = op.Mul(deviation, inv_std_dev)
-
+
     # Scale: NormalizedScaled = Mul(Normalized, Scale)
     normalized_scaled = op.Mul(normalized, scale)
-
+
     # Add bias: Y = Add(NormalizedScaled, B)
     result = op.Add(normalized_scaled, bias)
-
+
     return result


 class LayerNormFusionTest(unittest.TestCase):
     def _check(
         self,
-        test_data_constructor: OnnxFunction,
+        test_script: OnnxFunction,
         expected_graph_len: int,
         expected_op_type: str,
         has_bias: bool = False,
     ):
         """Helper method to run a fusion test scenario."""
-        model_proto = test_data_constructor.to_model_proto()
+        model_proto = test_script.to_model_proto()
         # Create test inputs
-        input_data = onnxscript.rewriter.testing.generate_random_inputs(model)
+        input_data = onnxscript.rewriter.testing.generate_random_inputs(model_proto)

         model = ir.serde.deserialize_model(model_proto)
         fuse_layer_normalization(model)

-        # Run original model
-        original_output = test_utils.ort_run("Original", model, input_data)
-
-        # Apply fusion
-        fuse_layer_normalization(model)
-        remove_unused_nodes(model)
+        # Check that a LayerNormalization node was created
+        self.assertIn("LayerNormalization", [n.op_type for n in model.graph])

-        # Verify fusion occurred
-        self.assertEqual(len(model.graph), expected_graph_len)
-        self.assertEqual(model.graph.node(0).op_type, expected_op_type)
+        fused_model_proto = ir.serde.serialize_model(model)

-        # Run optimized model and verify outputs match
-        optimized_output = test_utils.ort_run("Optimized", model, input_data)
-        test_utils.assert_allclose(original_output, optimized_output, rtol=1e-4, atol=1e-4)
+        onnxscript.rewriter.testing.assert_numerically_equal(
+            model_proto, fused_model_proto, input_data
+        )

     def test_layer_norm_fusion_without_bias(self):
         """Test LayerNorm fusion without bias."""
@@ -128,4 +120,4 @@ def test_layer_norm_fusion_with_bias(self):


 if __name__ == "__main__":
-    unittest.main()
+    unittest.main()
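The bodies of the two test methods are not shown in this diff. A hedged sketch of how they would drive the updated _check helper; the argument values here are assumptions, not taken from the commit:

    def test_layer_norm_fusion_without_bias(self):
        """Test LayerNorm fusion without bias."""
        self._check(_test_layer_norm_without_bias, 1, "LayerNormalization")

    def test_layer_norm_fusion_with_bias(self):
        """Test LayerNorm fusion with bias."""
        self._check(_test_layer_norm_with_bias, 1, "LayerNormalization", has_bias=True)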

onnxscript/rewriter/testing.py

Lines changed: 14 additions & 9 deletions
@@ -10,12 +10,13 @@

 from onnxscript import ir

+
 def generate_random_inputs(self, model: onnx.ModelProto) -> dict[str, Any]:
     feeds: dict[str, Any] = {}
     for input in model.graph.input:
         input_type = input.type.tensor_type
         shape = tuple(input_type.shape.dim)
-        if not all(hasattr(d, 'dim_value') for d in shape):
+        if not all(hasattr(d, "dim_value") for d in shape):
             raise ValueError(f"Input {input.name} has dynamic shape dimensions.")
         shape = tuple(d.dim_value for d in shape)
         if input_type.elem_type == onnx.TensorProto.FLOAT:
@@ -28,11 +29,10 @@ def generate_random_inputs(self, model: onnx.ModelProto) -> dict[str, Any]:
     return feeds


-
 def assert_numerically_equal(
     original_model_proto: onnx.ModelProto | ir.Model,
     rewritten_model_proto: onnx.ModelProto | ir.Model,
-    args: tuple[Any, ...],
+    args: tuple[Any, ...] | dict[str, Any],
     ort_optimization_level: ort.GraphOptimizationLevel = ort.GraphOptimizationLevel.ORT_ENABLE_ALL,
     rtol: float = 1,
     atol: float = 1e-3,
@@ -53,9 +53,17 @@ def assert_numerically_equal(
     if isinstance(rewritten_model_proto, ir.Model):
         rewritten_model_proto = ir.serde.serialize_model(rewritten_model_proto)

-    original_proto_ort_inputs = {
-        k.name: v for k, v in zip(original_model_proto.graph.input, args)
-    }
+    if isinstance(args, dict):
+        original_proto_ort_inputs = args
+        the_rewritten_proto_ort_inputs = args
+    else:
+        original_proto_ort_inputs = {
+            k.name: v for k, v in zip(original_model_proto.graph.input, args)
+        }
+        the_rewritten_proto_ort_inputs = {
+            k.name: v for k, v in zip(rewritten_model_proto.graph.input, args)
+        }
+
     original_proto_ort_inference_session = _ort_session_initializer(
         original_model_proto.SerializeToString(), ort_optimization_level
     )
@@ -65,9 +73,6 @@ def assert_numerically_equal(
         None, original_proto_ort_inputs, run_options=run_options
     )

-    the_rewritten_proto_ort_inputs = {
-        k.name: v for k, v in zip(rewritten_model_proto.graph.input, args)
-    }
     the_rewritten_proto_ort_inference_session = _ort_session_initializer(
         rewritten_model_proto.SerializeToString(), ort_optimization_level
     )
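With the widened args annotation above, assert_numerically_equal can now be fed either a positional tuple zipped against each graph's inputs or a ready-made name-to-array dict, as the rewritten test helper does. A small usage sketch; the model protos and input names below are illustrative, not from this commit:

import numpy as np

x = np.random.rand(2, 4, 8).astype(np.float32)
scale = np.random.rand(8).astype(np.float32)

# Tuple form: values are zipped, in order, against each model's graph inputs.
assert_numerically_equal(original_proto, rewritten_proto, (x, scale))

# Dict form: one name-keyed feed dict is passed unchanged to both ORT sessions,
# which is what generate_random_inputs() produces in the test above.
assert_numerically_equal(original_proto, rewritten_proto, {"x": x, "scale": scale})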
