|
9 | 9 | import operator |
10 | 10 |
|
11 | 11 | import torch |
| 12 | +from executorch.backends.arm._passes import ArmPass |
12 | 13 | from executorch.backends.arm._passes.arm_pass_utils import create_node |
13 | 14 | from executorch.exir.dialects._ops import ops as exir_ops |
14 | | -from executorch.exir.pass_base import ExportPass, PassResult |
| 15 | +from executorch.exir.pass_base import PassResult |
15 | 16 |
|
16 | 17 |
|
17 | 18 | def get_layer_norm_decomposition(op) -> tuple: |
@@ -40,7 +41,7 @@ def get_layer_norm_decomposition(op) -> tuple: |
40 | 41 | raise RuntimeError(f"Can't get layer_norm composition for op {op}") |
41 | 42 |
|
42 | 43 |
|
43 | | -class DecomposeLayerNormPass(ExportPass): |
| 44 | +class DecomposeLayerNormPass(ArmPass): |
44 | 45 | """ |
45 | 46 | layernorm is defined as: ((x - E[x]) / sqrt(Var[x] + eps)) * weights + bias |
46 | 47 | Decompose layernorm(x, normalized_shape, weights, bias, eps) to a sequence of: |
@@ -111,35 +112,56 @@ def call(self, graph_module: torch.fx.GraphModule): |
111 | 112 | var_op, |
112 | 113 | args=(x, dims), |
113 | 114 | kwargs={"correction": 0, "keepdim": keepdim}, |
| 115 | + from_node=node, |
114 | 116 | ) |
115 | 117 | full = create_node( |
116 | 118 | graph_module.graph, |
117 | 119 | full_op, |
118 | 120 | args=(epsilon_reshaped_shape, epsilon), |
119 | 121 | kwargs={"dtype": dtype}, |
| 122 | + from_node=node, |
| 123 | + ) |
| 124 | + add0 = create_node( |
| 125 | + graph_module.graph, add_op, args=(var, full), from_node=node |
| 126 | + ) |
| 127 | + rsqrt = create_node( |
| 128 | + graph_module.graph, rsqrt_op, args=(add0,), from_node=node |
| 129 | + ) |
| 130 | + mul0 = create_node( |
| 131 | + graph_module.graph, mul_op, args=(sub, rsqrt), from_node=node |
120 | 132 | ) |
121 | | - add0 = create_node(graph_module.graph, add_op, args=(var, full)) |
122 | | - rsqrt = create_node(graph_module.graph, rsqrt_op, args=(add0,)) |
123 | | - mul0 = create_node(graph_module.graph, mul_op, args=(sub, rsqrt)) |
124 | 133 | if weights is not None: |
125 | 134 | weights_reshaped = create_node( |
126 | 135 | graph_module.graph, |
127 | 136 | view_op, |
128 | 137 | args=(weights, weights_reshaped_shape), |
| 138 | + from_node=node, |
129 | 139 | ) |
130 | 140 | mul1 = create_node( |
131 | | - graph_module.graph, mul_op, args=(mul0, weights_reshaped) |
| 141 | + graph_module.graph, |
| 142 | + mul_op, |
| 143 | + args=( |
| 144 | + mul0, |
| 145 | + weights_reshaped, |
| 146 | + ), |
| 147 | + from_node=node, |
132 | 148 | ) |
133 | 149 | else: |
134 | 150 | mul1 = mul0 |
135 | 151 | output = mul1 |
136 | 152 | if bias is not None: |
137 | 153 | bias_reshaped_shape = weights_reshaped_shape |
138 | 154 | bias_reshaped = create_node( |
139 | | - graph_module.graph, view_op, args=(bias, bias_reshaped_shape) |
| 155 | + graph_module.graph, |
| 156 | + view_op, |
| 157 | + args=(bias, bias_reshaped_shape), |
| 158 | + from_node=node, |
140 | 159 | ) |
141 | 160 | output = create_node( |
142 | | - graph_module.graph, add_op, args=(mul1, bias_reshaped) |
| 161 | + graph_module.graph, |
| 162 | + add_op, |
| 163 | + args=(mul1, bias_reshaped), |
| 164 | + from_node=node, |
143 | 165 | ) |
144 | 166 |
|
145 | 167 | users = [user for user in node.users if node != user] |
|
0 commit comments