1+ # Copyright (c) Microsoft Corporation.
2+ # Licensed under the MIT License.
3+
4+ import unittest
5+
6+ import numpy as np
7+ import onnx_ir as ir
8+ import parameterized
9+
10+ import onnxscript
11+ import onnxscript .rewriter .ort_fusions ._test_utils as test_utils
12+ from onnxscript import FLOAT , OnnxFunction , script
13+ from onnxscript import opset17 as op
14+ from onnxscript .optimizer import optimize , remove_unused_nodes
15+ from onnxscript .rewriter .layer_normalization import fuse_layer_normalization
16+ import onnxscript .rewriter .testing
17+
18+
@script()
def _test_layer_norm_without_bias(x: FLOAT[2, 4, 8], scale: FLOAT[8]) -> FLOAT[2, 4, 8]:
    """Decomposed LayerNorm subgraph (no bias) used as fusion-test input.

    Spells out ``(x - mean) / sqrt(var + eps) * scale`` as individual ONNX
    ops so the fusion rule has an unfused pattern to match.
    """
    # Mean over the last (normalized) axis, keeping the reduced dimension.
    mu = op.ReduceMean(x, [-1], keepdims=1)

    # Center the input around its mean.
    centered = op.Sub(x, mu)

    # Variance is the mean of the squared deviations.
    var = op.ReduceMean(op.Mul(centered, centered), [-1], keepdims=1)

    # Numerical-stability epsilon added before the square root.
    eps = op.Constant(value_float=1e-5)

    # 1 / sqrt(var + eps)
    inv_std = op.Reciprocal(op.Sqrt(op.Add(var, eps)))

    # Normalize, then apply the elementwise scale.
    return op.Mul(op.Mul(centered, inv_std), scale)
51+
52+
@script()
def _test_layer_norm_with_bias(x: FLOAT[2, 4, 8], scale: FLOAT[8], bias: FLOAT[8]) -> FLOAT[2, 4, 8]:
    """Decomposed LayerNorm subgraph (with bias) used as fusion-test input.

    Spells out ``(x - mean) / sqrt(var + eps) * scale + bias`` as individual
    ONNX ops so the fusion rule has an unfused pattern to match.
    """
    # Mean over the last (normalized) axis, keeping the reduced dimension.
    mu = op.ReduceMean(x, [-1], keepdims=1)

    # Center the input around its mean.
    centered = op.Sub(x, mu)

    # Variance is the mean of the squared deviations.
    var = op.ReduceMean(op.Mul(centered, centered), [-1], keepdims=1)

    # Numerical-stability epsilon added before the square root.
    eps = op.Constant(value_float=1e-5)

    # 1 / sqrt(var + eps)
    inv_std = op.Reciprocal(op.Sqrt(op.Add(var, eps)))

    # Normalize, apply the elementwise scale, then add the bias.
    scaled = op.Mul(op.Mul(centered, inv_std), scale)
    return op.Add(scaled, bias)
88+
89+
class LayerNormFusionTest(unittest.TestCase):
    def _check(
        self,
        test_data_constructor: OnnxFunction,
        expected_graph_len: int,
        expected_op_type: str,
        has_bias: bool = False,
    ):
        """Run one fusion scenario: fuse, verify graph shape, compare outputs.

        Args:
            test_data_constructor: Script function that builds the test model.
            expected_graph_len: Expected node count after fusion.
            expected_op_type: Expected op type of the first (fused) node.
            has_bias: Whether the pattern includes a trailing bias Add.
                Kept for call-site readability; not consulted here.
        """
        model_proto = test_data_constructor.to_model_proto()
        # Create test inputs from the proto. (Bug fix: the original code
        # passed `model`, which was not defined until two lines later.)
        input_data = onnxscript.rewriter.testing.generate_random_inputs(model_proto)

        model = ir.serde.deserialize_model(model_proto)

        # Run the ORIGINAL (unfused) model first. (Bug fix: the original
        # code fused before this run, so the "baseline" was already the
        # optimized model and the final comparison was vacuous.)
        original_output = test_utils.ort_run("Original", model, input_data)

        # Apply fusion exactly once, then drop nodes orphaned by the rewrite.
        fuse_layer_normalization(model)
        remove_unused_nodes(model)

        # Verify fusion occurred.
        self.assertEqual(len(model.graph), expected_graph_len)
        self.assertEqual(model.graph.node(0).op_type, expected_op_type)

        # Run optimized model and verify outputs numerically match.
        optimized_output = test_utils.ort_run("Optimized", model, input_data)
        test_utils.assert_allclose(original_output, optimized_output, rtol=1e-4, atol=1e-4)

    def test_layer_norm_fusion_without_bias(self):
        """LayerNorm pattern without bias fuses to a single node."""
        self._check(_test_layer_norm_without_bias, 1, "LayerNormalization", has_bias=False)

    def test_layer_norm_fusion_with_bias(self):
        """LayerNorm pattern with bias fuses to a single node."""
        self._check(_test_layer_norm_with_bias, 1, "LayerNormalization", has_bias=True)
128+
129+
# Allow this test module to be executed directly from the command line.
if __name__ == "__main__":
    unittest.main()