Commit f8d65e9

Prevent decomposing RMSNorm in Jarvis (pytorch#10074)
Summary: D72276082 added support for RMSNorm in the Turing compiler; this diff does the same for Jarvis and removes the now-redundant op registrations. Differential Revision: D72802423
1 parent c38e71f commit f8d65e9
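
For context on what "prevent decomposing" means here: the Cadence AOT flow builds a decomposition table and strips out the entries for ops it wants the backend to see as single nodes, so export keeps them intact instead of lowering them to primitive ops. Below is a minimal sketch of that pattern, not the exact Jarvis/Cadence pipeline; the `nn.RMSNorm` example model and the use of `torch.export.default_decompositions()` as the starting table are assumptions for illustration.

```python
# Minimal sketch (assumptions noted above): keep aten.rms_norm from being
# decomposed when running decompositions on an exported program.
import torch
from torch._inductor.decomposition import remove_decompositions

model = torch.nn.RMSNorm(64)
example_inputs = (torch.randn(2, 64),)

# Start from the default decomposition table and drop the entries for the
# ops we want to keep, mirroring the ops_to_keep list in compiler.py.
decomp_table = torch.export.default_decompositions()
ops_to_keep = [torch.ops.aten.rms_norm.default]
remove_decompositions(decomp_table, ops_to_keep)

ep = torch.export.export(model, example_inputs)
# With rms_norm removed from the table, the decomposed graph should still
# contain a single aten.rms_norm.default call rather than its primitives.
ep = ep.run_decompositions(decomp_table)
print(ep.graph_module.code)
```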

File tree: 2 files changed, +6 -7 lines

backends/cadence/aot/compiler.py

Lines changed: 6 additions & 3 deletions
@@ -31,11 +31,11 @@
     EdgeProgramManager,
     ExecutorchBackendConfig,
     ExecutorchProgramManager,
-    to_edge,
 )
 from executorch.exir.pass_base import PassResult
 from executorch.exir.passes import ToOutVarPass
 from executorch.exir.passes.sym_shape_eval_pass import HintBasedSymShapeEvalPass
+from executorch.exir.program._program import to_edge_with_preserved_ops
 from torch._inductor.decomposition import remove_decompositions
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 
@@ -80,6 +80,7 @@ def convert_pt2(
         torch.ops.aten.layer_norm.default,
         torch.ops.aten.linear.default,
         torch.ops.aten.matmul.default,
+        torch.ops.aten.rms_norm.default,
     ]
     # Remove decompositions for the ops we want to keep
     # pyre-fixme[6]: For 1st argument expected `Dict[typing.Callable[..., typing.Any

@@ -201,9 +202,9 @@ def lower_ep_to_edge(
     """
     Lower an ExportedProgram to an EdgeProgramManager (in edge IR).
     """
-    # Call to_edge to convert the graph to edge IR.
+    # Call to_edge_with_preserved_ops to convert the graph to edge IR.
     # Note: dim_order is skipped (https://github.com/pytorch/executorch/issues/3704)
-    edge_prog_manager = to_edge(
+    edge_prog_manager = to_edge_with_preserved_ops(
         expo_program,
         compile_config=EdgeCompileConfig(
             _skip_dim_order=True,

@@ -216,9 +217,11 @@ def lower_ep_to_edge(
                 torch.ops.aten.linalg_vector_norm.default,
                 torch.ops.aten.unfold.default,
                 torch.ops.aten.angle.default,
+                torch.ops.aten.rms_norm.default,
             ],
         ),
         constant_methods=constant_methods,
+        preserve_ops=(torch.ops.aten.rms_norm.default,),
     )
 
     if dump_graphs:
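
To show how the new `preserve_ops` argument fits together with the hunk above, here is a hedged sketch of an edge-lowering call that keeps `aten.rms_norm.default` as an edge op. It is a simplified stand-in for `lower_ep_to_edge`, not the full function: the helper name is made up, and the `EdgeCompileConfig` is reduced to the one field visible in the diff.

```python
# Illustrative sketch only: lower an already-exported program to edge IR
# while asking the converter to preserve aten.rms_norm.default.
import torch
from executorch.exir import EdgeCompileConfig, EdgeProgramManager
from executorch.exir.program._program import to_edge_with_preserved_ops


def lower_keeping_rms_norm(expo_program, constant_methods=None) -> EdgeProgramManager:
    # Hypothetical helper; mirrors the keyword arguments used in the hunk above.
    return to_edge_with_preserved_ops(
        expo_program,
        compile_config=EdgeCompileConfig(
            # dim_order is skipped, as in the original call
            # (https://github.com/pytorch/executorch/issues/3704).
            _skip_dim_order=True,
        ),
        constant_methods=constant_methods,
        preserve_ops=(torch.ops.aten.rms_norm.default,),
    )
```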

backends/cadence/aot/ops_registrations.py

Lines changed: 0 additions & 4 deletions
@@ -139,7 +139,6 @@
     "int in_zero_point, bool channel_last=False) -> (Tensor out)"
 )
 lib.define("linalg_vector_norm(Tensor X) -> (Tensor Y)")
-lib.define("rms_norm(Tensor X, float eps, Tensor W) -> (Tensor Y)")
 lib.define(
     "transposed_im2row(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, "
     "int[2] output_padding, Tensor in_zero_point, bool channel_last=False) -> (Tensor out)"

@@ -211,9 +210,6 @@
     "fully_connected.out(Tensor input, Tensor weight, Tensor? bias=None, *, Tensor(a!) out) -> Tensor(a!)"
 )
 lib.define("linalg_vector_norm.out(Tensor X, *, Tensor(a!) out) -> Tensor(a!)")
-lib.define(
-    "rms_norm.out(Tensor X, float eps, Tensor W, *, Tensor(a!) out) -> Tensor(a!)"
-)
 lib.define(
     "quantized_fully_connected.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
     "Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"
