Skip to content

Commit 68c6d89

Browse files
angelayifacebook-github-bot
authored andcommitted
Fixes in to_executorch for while (#12062)
Summary: Fixes while loop support in executorch AOT up until emitter (#8769 (comment)) Differential Revision: D77452951
1 parent b74c68d commit 68c6d89

File tree

3 files changed

+18
-2
lines changed

3 files changed

+18
-2
lines changed

exir/passes/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ def get_submodule(node: torch.fx.Node) -> torch.fx.GraphModule:
340340
if target == torch.ops.higher_order.map_impl:
341341
self.call(get_submodule(node.args[0]))
342342
continue
343-
elif target == control_flow.while_loop:
343+
elif target == torch.ops.higher_order.while_loop:
344344
self.call(get_submodule(node.args[0]))
345345
self.call(get_submodule(node.args[1]))
346346
continue

exir/passes/spec_prop_pass.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919

2020
# pyre-ignore
2121
def make_spec(x):
22-
if isinstance(x, torch.Tensor):
22+
if isinstance(x, ProxyValue):
23+
return make_spec(x.node.meta["val"])
24+
elif isinstance(x, torch.Tensor):
2325
return TensorSpec.from_tensor(x)
2426
elif isinstance(x, (int, bool, float)):
2527
return x
@@ -109,6 +111,19 @@ def call_cond(self, pred, true_fn, false_fn, inputs, meta):
109111
meta["spec"] = pytree.tree_map(make_spec, true_out_node.meta["val"])
110112
return super().call_cond(pred, true_fn, false_fn, inputs, meta)
111113

114+
def call_while(
115+
self,
116+
cond_fn: torch.fx.GraphModule,
117+
body_fn: torch.fx.GraphModule,
118+
carried_inputs: List[ProxyValue],
119+
additional_inputs: List[ProxyValue],
120+
meta: NodeMetadata,
121+
):
122+
meta["spec"] = pytree.tree_map(make_spec, carried_inputs)
123+
return super().call_while(
124+
cond_fn, body_fn, carried_inputs, additional_inputs, meta
125+
)
126+
112127
def call_map(
113128
self,
114129
f: torch.fx.GraphModule,

exir/program/test/test_program.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,7 @@ def body_fn(it, x):
314314
inp = (torch.tensor(3), torch.randn(2, 2))
315315
exported = export(M(), inp)
316316
to_edge(exported)
317+
# ep.to_executorch()
317318

318319
def test_constraint_present_after_dce(self):
319320
import executorch.exir as exir

0 commit comments

Comments
 (0)