|  | 
| 9 | 9 | import operator | 
| 10 | 10 | 
 | 
| 11 | 11 | import torch | 
|  | 12 | +from executorch.backends.arm._passes import ArmPass | 
| 12 | 13 | from executorch.backends.arm._passes.arm_pass_utils import create_node | 
| 13 | 14 | from executorch.exir.dialects._ops import ops as exir_ops | 
| 14 |  | -from executorch.exir.pass_base import ExportPass, PassResult | 
|  | 15 | +from executorch.exir.pass_base import PassResult | 
| 15 | 16 | 
 | 
| 16 | 17 | 
 | 
| 17 | 18 | def get_layer_norm_decomposition(op) -> tuple: | 
| @@ -40,7 +41,7 @@ def get_layer_norm_decomposition(op) -> tuple: | 
| 40 | 41 |     raise RuntimeError(f"Can't get layer_norm composition for op {op}") | 
| 41 | 42 | 
 | 
| 42 | 43 | 
 | 
| 43 |  | -class DecomposeLayerNormPass(ExportPass): | 
|  | 44 | +class DecomposeLayerNormPass(ArmPass): | 
| 44 | 45 |     """ | 
| 45 | 46 |     layernorm is defined as: ((x - E[x]) / sqrt(Var[x] + eps)) * weights + bias | 
| 46 | 47 |     Decompose layernorm(x, normalized_shape, weights, bias, eps) to a sequence of: | 
| @@ -111,35 +112,56 @@ def call(self, graph_module: torch.fx.GraphModule): | 
| 111 | 112 |                     var_op, | 
| 112 | 113 |                     args=(x, dims), | 
| 113 | 114 |                     kwargs={"correction": 0, "keepdim": keepdim}, | 
|  | 115 | +                    from_node=node, | 
| 114 | 116 |                 ) | 
| 115 | 117 |                 full = create_node( | 
| 116 | 118 |                     graph_module.graph, | 
| 117 | 119 |                     full_op, | 
| 118 | 120 |                     args=(epsilon_reshaped_shape, epsilon), | 
| 119 | 121 |                     kwargs={"dtype": dtype}, | 
|  | 122 | +                    from_node=node, | 
|  | 123 | +                ) | 
|  | 124 | +                add0 = create_node( | 
|  | 125 | +                    graph_module.graph, add_op, args=(var, full), from_node=node | 
|  | 126 | +                ) | 
|  | 127 | +                rsqrt = create_node( | 
|  | 128 | +                    graph_module.graph, rsqrt_op, args=(add0,), from_node=node | 
|  | 129 | +                ) | 
|  | 130 | +                mul0 = create_node( | 
|  | 131 | +                    graph_module.graph, mul_op, args=(sub, rsqrt), from_node=node | 
| 120 | 132 |                 ) | 
| 121 |  | -                add0 = create_node(graph_module.graph, add_op, args=(var, full)) | 
| 122 |  | -                rsqrt = create_node(graph_module.graph, rsqrt_op, args=(add0,)) | 
| 123 |  | -                mul0 = create_node(graph_module.graph, mul_op, args=(sub, rsqrt)) | 
| 124 | 133 |                 if weights is not None: | 
| 125 | 134 |                     weights_reshaped = create_node( | 
| 126 | 135 |                         graph_module.graph, | 
| 127 | 136 |                         view_op, | 
| 128 | 137 |                         args=(weights, weights_reshaped_shape), | 
|  | 138 | +                        from_node=node, | 
| 129 | 139 |                     ) | 
| 130 | 140 |                     mul1 = create_node( | 
| 131 |  | -                        graph_module.graph, mul_op, args=(mul0, weights_reshaped) | 
|  | 141 | +                        graph_module.graph, | 
|  | 142 | +                        mul_op, | 
|  | 143 | +                        args=( | 
|  | 144 | +                            mul0, | 
|  | 145 | +                            weights_reshaped, | 
|  | 146 | +                        ), | 
|  | 147 | +                        from_node=node, | 
| 132 | 148 |                     ) | 
| 133 | 149 |                 else: | 
| 134 | 150 |                     mul1 = mul0 | 
| 135 | 151 |                 output = mul1 | 
| 136 | 152 |                 if bias is not None: | 
| 137 | 153 |                     bias_reshaped_shape = weights_reshaped_shape | 
| 138 | 154 |                     bias_reshaped = create_node( | 
| 139 |  | -                        graph_module.graph, view_op, args=(bias, bias_reshaped_shape) | 
|  | 155 | +                        graph_module.graph, | 
|  | 156 | +                        view_op, | 
|  | 157 | +                        args=(bias, bias_reshaped_shape), | 
|  | 158 | +                        from_node=node, | 
| 140 | 159 |                     ) | 
| 141 | 160 |                     output = create_node( | 
| 142 |  | -                        graph_module.graph, add_op, args=(mul1, bias_reshaped) | 
|  | 161 | +                        graph_module.graph, | 
|  | 162 | +                        add_op, | 
|  | 163 | +                        args=(mul1, bias_reshaped), | 
|  | 164 | +                        from_node=node, | 
| 143 | 165 |                     ) | 
| 144 | 166 | 
 | 
| 145 | 167 |                 users = [user for user in node.users if node != user] | 
|  | 
0 commit comments