@@ -3,120 +3,112 @@
 
 import unittest
 
-import numpy as np
 import onnx_ir as ir
-import parameterized
 
 import onnxscript
-import onnxscript.rewriter.ort_fusions._test_utils as test_utils
+import onnxscript.rewriter.testing
 from onnxscript import FLOAT, OnnxFunction, script
-from onnxscript import opset17 as op
-from onnxscript.optimizer import optimize, remove_unused_nodes
+from onnxscript import opset18 as op
 from onnxscript.rewriter.layer_normalization import fuse_layer_normalization
-import onnxscript.rewriter.testing
 
 
 @script()
 def _test_layer_norm_without_bias(x: FLOAT[2, 4, 8], scale: FLOAT[8]) -> FLOAT[2, 4, 8]:
     """LayerNorm pattern without bias."""
     # Compute mean: Mean = ReduceMean(X, axes=normalized_axes)
     mean = op.ReduceMean(x, [-1], keepdims=1)
-
+
     # Compute deviation: D = Sub(X, Mean)
     deviation = op.Sub(x, mean)
-
+
     # Compute squared deviation: DD = Mul(D, D)
     deviation_squared = op.Mul(deviation, deviation)
-
+
     # Compute variance: Var = ReduceMean(DD, axes=normalized_axes)
     variance = op.ReduceMean(deviation_squared, [-1], keepdims=1)
-
+
     # Add epsilon: VarEps = Add(Var, epsilon)
     epsilon = op.Constant(value_float=1e-5)
     variance_plus_epsilon = op.Add(variance, epsilon)
-
+
     # Compute standard deviation: StdDev = Sqrt(VarEps)
     std_dev = op.Sqrt(variance_plus_epsilon)
-
+
     # Compute reciprocal: InvStdDev = Reciprocal(StdDev)
     inv_std_dev = op.Reciprocal(std_dev)
-
+
     # Normalize: Normalized = Mul(D, InvStdDev)
     normalized = op.Mul(deviation, inv_std_dev)
-
+
     # Scale: NormalizedScaled = Mul(Normalized, Scale)
     normalized_scaled = op.Mul(normalized, scale)
-
+
     return normalized_scaled
 
 
 @script()
-def _test_layer_norm_with_bias(x: FLOAT[2, 4, 8], scale: FLOAT[8], bias: FLOAT[8]) -> FLOAT[2, 4, 8]:
+def _test_layer_norm_with_bias(
+    x: FLOAT[2, 4, 8], scale: FLOAT[8], bias: FLOAT[8]
+) -> FLOAT[2, 4, 8]:
     """LayerNorm pattern with bias."""
     # Compute mean: Mean = ReduceMean(X, axes=normalized_axes)
     mean = op.ReduceMean(x, [-1], keepdims=1)
-
+
     # Compute deviation: D = Sub(X, Mean)
     deviation = op.Sub(x, mean)
-
+
     # Compute squared deviation: DD = Mul(D, D)
     deviation_squared = op.Mul(deviation, deviation)
-
+
     # Compute variance: Var = ReduceMean(DD, axes=normalized_axes)
     variance = op.ReduceMean(deviation_squared, [-1], keepdims=1)
-
+
     # Add epsilon: VarEps = Add(Var, epsilon)
     epsilon = op.Constant(value_float=1e-5)
     variance_plus_epsilon = op.Add(variance, epsilon)
-
+
     # Compute standard deviation: StdDev = Sqrt(VarEps)
     std_dev = op.Sqrt(variance_plus_epsilon)
-
+
     # Compute reciprocal: InvStdDev = Reciprocal(StdDev)
     inv_std_dev = op.Reciprocal(std_dev)
-
+
     # Normalize: Normalized = Mul(D, InvStdDev)
     normalized = op.Mul(deviation, inv_std_dev)
-
+
     # Scale: NormalizedScaled = Mul(Normalized, Scale)
     normalized_scaled = op.Mul(normalized, scale)
-
+
     # Add bias: Y = Add(NormalizedScaled, B)
     result = op.Add(normalized_scaled, bias)
-
+
     return result
 
 
 class LayerNormFusionTest(unittest.TestCase):
     def _check(
         self,
-        test_data_constructor: OnnxFunction,
+        test_script: OnnxFunction,
         expected_graph_len: int,
         expected_op_type: str,
         has_bias: bool = False,
     ):
         """Helper method to run a fusion test scenario."""
-        model_proto = test_data_constructor.to_model_proto()
+        model_proto = test_script.to_model_proto()
         # Create test inputs
-        input_data = onnxscript.rewriter.testing.generate_random_inputs(model)
+        input_data = onnxscript.rewriter.testing.generate_random_inputs(model_proto)
 
         model = ir.serde.deserialize_model(model_proto)
         fuse_layer_normalization(model)
 
-        # Run original model
-        original_output = test_utils.ort_run("Original", model, input_data)
-
-        # Apply fusion
-        fuse_layer_normalization(model)
-        remove_unused_nodes(model)
+        # Check that a LayerNormalization node was created
+        self.assertIn("LayerNormalization", [n.op_type for n in model.graph])
 
-        # Verify fusion occurred
-        self.assertEqual(len(model.graph), expected_graph_len)
-        self.assertEqual(model.graph.node(0).op_type, expected_op_type)
+        fused_model_proto = ir.serde.serialize_model(model)
 
-        # Run optimized model and verify outputs match
-        optimized_output = test_utils.ort_run("Optimized", model, input_data)
-        test_utils.assert_allclose(original_output, optimized_output, rtol=1e-4, atol=1e-4)
+        onnxscript.rewriter.testing.assert_numerically_equal(
+            model_proto, fused_model_proto, input_data
+        )
 
     def test_layer_norm_fusion_without_bias(self):
         """Test LayerNorm fusion without bias."""
@@ -128,4 +120,4 @@ def test_layer_norm_fusion_with_bias(self):
 
 
 if __name__ == "__main__":
-    unittest.main()
+    unittest.main()
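For reference, the decomposed pattern that both test scripts build, and that the rewrite collapses into a single LayerNormalization node, is the standard layer normalization over the last axis. A sketch of the math, with $\gamma$ standing for `scale` and $\beta$ for the optional `bias`:

$$\mathrm{LayerNorm}(x) = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} \cdot \gamma + \beta, \qquad \epsilon = 10^{-5},$$

where the variance is computed as the mean of the squared deviations, $\mathrm{Var}[x] = \mathrm{E}\big[(x - \mathrm{E}[x])^2\big]$, exactly matching the ReduceMean/Sub/Mul sequence in the scripts above.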
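A minimal sketch of the round-trip that the rewritten `_check` helper performs: build a ModelProto from the script, fuse on the IR, then compare the fused model numerically against the original on random inputs. Every call below appears in the diff; `fuse_and_verify` itself is a hypothetical name, not part of the commit.

# Sketch (not part of the commit) of the verification round-trip in _check.
import onnx_ir as ir

import onnxscript.rewriter.testing
from onnxscript.rewriter.layer_normalization import fuse_layer_normalization


def fuse_and_verify(model_proto):
    # Random inputs are generated from the *original* ModelProto.
    inputs = onnxscript.rewriter.testing.generate_random_inputs(model_proto)
    model = ir.serde.deserialize_model(model_proto)
    fuse_layer_normalization(model)
    # The fusion should have introduced a LayerNormalization node.
    assert "LayerNormalization" in [n.op_type for n in model.graph]
    fused_proto = ir.serde.serialize_model(model)
    # Raises if the fused model's outputs diverge from the original's.
    onnxscript.rewriter.testing.assert_numerically_equal(
        model_proto, fused_proto, inputs
    )
    return fused_proto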