Skip to content

Commit 9d599c9

Browse files
authored
Fixes in to_executorch for while
Differential Revision: D77452951 Pull Request resolved: #12062
1 parent cb3b99a commit 9d599c9

File tree

4 files changed

+29
-8
lines changed

4 files changed

+29
-8
lines changed

exir/emit/_emitter.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1480,10 +1480,11 @@ def call_function( # pyre-fixme[14]
14801480
# pyre-ignore
14811481
return self._emit_free(args[0])
14821482

1483-
elif target is torch.ops.higher_order.cond:
1484-
return self._emit_control_flow(target, args, kwargs)
1485-
1486-
elif target is torch.ops.higher_order.map_impl:
1483+
elif target in (
1484+
torch.ops.higher_order.cond,
1485+
torch.ops.higher_order.map_impl,
1486+
torch.ops.higher_order.while_loop,
1487+
):
14871488
return self._emit_control_flow(target, args, kwargs)
14881489

14891490
elif target == executorch_call_delegate:

exir/passes/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ def get_submodule(node: torch.fx.Node) -> torch.fx.GraphModule:
340340
if target == torch.ops.higher_order.map_impl:
341341
self.call(get_submodule(node.args[0]))
342342
continue
343-
elif target == control_flow.while_loop:
343+
elif target == torch.ops.higher_order.while_loop:
344344
self.call(get_submodule(node.args[0]))
345345
self.call(get_submodule(node.args[1]))
346346
continue

exir/passes/spec_prop_pass.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919

2020
# pyre-ignore
2121
def make_spec(x):
22-
if isinstance(x, torch.Tensor):
22+
if isinstance(x, ProxyValue):
23+
return make_spec(x.node.meta["val"])
24+
elif isinstance(x, torch.Tensor):
2325
return TensorSpec.from_tensor(x)
2426
elif isinstance(x, (int, bool, float)):
2527
return x
@@ -109,6 +111,19 @@ def call_cond(self, pred, true_fn, false_fn, inputs, meta):
109111
meta["spec"] = pytree.tree_map(make_spec, true_out_node.meta["val"])
110112
return super().call_cond(pred, true_fn, false_fn, inputs, meta)
111113

114+
def call_while(
115+
self,
116+
cond_fn: torch.fx.GraphModule,
117+
body_fn: torch.fx.GraphModule,
118+
carried_inputs: List[ProxyValue],
119+
additional_inputs: List[ProxyValue],
120+
meta: NodeMetadata,
121+
):
122+
meta["spec"] = pytree.tree_map(make_spec, carried_inputs)
123+
return super().call_while(
124+
cond_fn, body_fn, carried_inputs, additional_inputs, meta
125+
)
126+
112127
def call_map(
113128
self,
114129
f: torch.fx.GraphModule,

exir/program/test/test_program.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
NonDecompTestPartitioner,
1818
)
1919
from executorch.exir.dialects._ops import ops as exir_ops
20-
from executorch.exir.error import ExportError
20+
from executorch.exir.error import ExportError, InternalError
2121
from executorch.exir.lowered_backend_module import get_lowered_submodules
2222
from executorch.exir.pass_base import ExportPass
2323
from executorch.exir.passes import MemoryPlanningPass
@@ -313,7 +313,12 @@ def body_fn(it, x):
313313
# Instantiate and export
314314
inp = (torch.tensor(3), torch.randn(2, 2))
315315
exported = export(M(), inp)
316-
to_edge(exported)
316+
ep = to_edge(exported)
317+
# TODO(jakeszwe)
318+
with self.assertRaisesRegex(
319+
InternalError, "Unsupported control flow operator: while_loop"
320+
):
321+
ep.to_executorch()
317322

318323
def test_constraint_present_after_dce(self):
319324
import executorch.exir as exir

0 commit comments

Comments
 (0)