From 88ba29683f8881fcd307929bea277fe7701c4b1c Mon Sep 17 00:00:00 2001
From: Angela Yi
Date: Tue, 1 Jul 2025 12:37:46 -0700
Subject: [PATCH] Fixes in to_executorch for while (#12062)

Summary:
Fixes while loop support in executorch AOT up until emitter
(https://github.com/pytorch/executorch/issues/8769#issuecomment-3012018917)

Reviewed By: JacobSzwejbka

Differential Revision: D77452951
---
 exir/emit/_emitter.py             |  9 +++++----
 exir/passes/__init__.py           |  2 +-
 exir/passes/spec_prop_pass.py     | 17 ++++++++++++++++-
 exir/program/test/test_program.py |  9 +++++++--
 4 files changed, 29 insertions(+), 8 deletions(-)

diff --git a/exir/emit/_emitter.py b/exir/emit/_emitter.py
index fe18e49a623..5ee8ca56091 100644
--- a/exir/emit/_emitter.py
+++ b/exir/emit/_emitter.py
@@ -1480,10 +1480,11 @@ def call_function(  # pyre-fixme[14]
             # pyre-ignore
             return self._emit_free(args[0])
 
-        elif target is torch.ops.higher_order.cond:
-            return self._emit_control_flow(target, args, kwargs)
-
-        elif target is torch.ops.higher_order.map_impl:
+        elif target in (
+            torch.ops.higher_order.cond,
+            torch.ops.higher_order.map_impl,
+            torch.ops.higher_order.while_loop,
+        ):
             return self._emit_control_flow(target, args, kwargs)
 
         elif target == executorch_call_delegate:
diff --git a/exir/passes/__init__.py b/exir/passes/__init__.py
index 777b2a1c866..5c6eb63db46 100644
--- a/exir/passes/__init__.py
+++ b/exir/passes/__init__.py
@@ -340,7 +340,7 @@ def get_submodule(node: torch.fx.Node) -> torch.fx.GraphModule:
             if target == torch.ops.higher_order.map_impl:
                 self.call(get_submodule(node.args[0]))
                 continue
-            elif target == control_flow.while_loop:
+            elif target == torch.ops.higher_order.while_loop:
                 self.call(get_submodule(node.args[0]))
                 self.call(get_submodule(node.args[1]))
                 continue
diff --git a/exir/passes/spec_prop_pass.py b/exir/passes/spec_prop_pass.py
index 25eb5beaa75..ab5367d1b20 100644
--- a/exir/passes/spec_prop_pass.py
+++ b/exir/passes/spec_prop_pass.py
@@ -19,7 +19,9 @@
 
 # pyre-ignore
 def make_spec(x):
-    if isinstance(x, torch.Tensor):
+    if isinstance(x, ProxyValue):
+        return make_spec(x.node.meta["val"])
+    elif isinstance(x, torch.Tensor):
         return TensorSpec.from_tensor(x)
     elif isinstance(x, (int, bool, float)):
         return x
@@ -109,6 +111,19 @@ def call_cond(self, pred, true_fn, false_fn, inputs, meta):
         meta["spec"] = pytree.tree_map(make_spec, true_out_node.meta["val"])
         return super().call_cond(pred, true_fn, false_fn, inputs, meta)
 
+    def call_while(
+        self,
+        cond_fn: torch.fx.GraphModule,
+        body_fn: torch.fx.GraphModule,
+        carried_inputs: List[ProxyValue],
+        additional_inputs: List[ProxyValue],
+        meta: NodeMetadata,
+    ):
+        meta["spec"] = pytree.tree_map(make_spec, carried_inputs)
+        return super().call_while(
+            cond_fn, body_fn, carried_inputs, additional_inputs, meta
+        )
+
     def call_map(
         self,
         f: torch.fx.GraphModule,
diff --git a/exir/program/test/test_program.py b/exir/program/test/test_program.py
index 611e4b5f8a0..d5de78909ce 100644
--- a/exir/program/test/test_program.py
+++ b/exir/program/test/test_program.py
@@ -17,7 +17,7 @@
     NonDecompTestPartitioner,
 )
 from executorch.exir.dialects._ops import ops as exir_ops
-from executorch.exir.error import ExportError
+from executorch.exir.error import ExportError, InternalError
 from executorch.exir.lowered_backend_module import get_lowered_submodules
 from executorch.exir.pass_base import ExportPass
 from executorch.exir.passes import MemoryPlanningPass
@@ -313,7 +313,12 @@ def body_fn(it, x):
         # Instantiate and export
         inp = (torch.tensor(3), torch.randn(2, 2))
         exported = export(M(), inp)
-        to_edge(exported)
+        ep = to_edge(exported)
+        # TODO(jakeszwe)
+        with self.assertRaisesRegex(
+            InternalError, "Unsupported control flow operator: while_loop"
+        ):
+            ep.to_executorch()
 
     def test_constraint_present_after_dce(self):
         import executorch.exir as exir
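
For context, below is a minimal repro sketch of the flow the new test exercises. The module body is an assumption modeled on the hunk context above (cond_fn/body_fn carrying an iteration counter and a tensor), not code copied from test_program.py; it uses the torch._higher_order_ops.while_loop API. With this patch, export and to_edge carry the while_loop node through the AOT passes, and to_executorch() is still expected to raise InternalError until the emitter itself handles while_loop.

# Hypothetical repro sketch (not part of the patch); M/cond_fn/body_fn mirror
# the test's hunk context, but their bodies are assumptions.
import torch
from torch.export import export
from executorch.exir import to_edge


class M(torch.nn.Module):
    def forward(self, it, x):
        def cond_fn(it, x):
            # Keep looping while the counter is positive.
            return it > 0

        def body_fn(it, x):
            # Carried values must keep the same shapes/dtypes across iterations.
            return it - 1, x + it

        return torch._higher_order_ops.while_loop(cond_fn, body_fn, (it, x))


inp = (torch.tensor(3), torch.randn(2, 2))
ep = to_edge(export(M(), inp)))  # succeeds with this patch
# Serialization still fails until the emitter supports while_loop:
#     ep.to_executorch()  # InternalError: Unsupported control flow operator: while_loop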