Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions exir/emit/_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1480,10 +1480,11 @@ def call_function( # pyre-fixme[14]
# pyre-ignore
return self._emit_free(args[0])

elif target is torch.ops.higher_order.cond:
return self._emit_control_flow(target, args, kwargs)

elif target is torch.ops.higher_order.map_impl:
elif target in (
torch.ops.higher_order.cond,
torch.ops.higher_order.map_impl,
torch.ops.higher_order.while_loop,
):
return self._emit_control_flow(target, args, kwargs)

elif target == executorch_call_delegate:
Expand Down
2 changes: 1 addition & 1 deletion exir/passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ def get_submodule(node: torch.fx.Node) -> torch.fx.GraphModule:
if target == torch.ops.higher_order.map_impl:
self.call(get_submodule(node.args[0]))
continue
elif target == control_flow.while_loop:
elif target == torch.ops.higher_order.while_loop:
self.call(get_submodule(node.args[0]))
self.call(get_submodule(node.args[1]))
continue
Expand Down
17 changes: 16 additions & 1 deletion exir/passes/spec_prop_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@

# pyre-ignore
def make_spec(x):
if isinstance(x, torch.Tensor):
if isinstance(x, ProxyValue):
return make_spec(x.node.meta["val"])
elif isinstance(x, torch.Tensor):
return TensorSpec.from_tensor(x)
elif isinstance(x, (int, bool, float)):
return x
Expand Down Expand Up @@ -109,6 +111,19 @@ def call_cond(self, pred, true_fn, false_fn, inputs, meta):
meta["spec"] = pytree.tree_map(make_spec, true_out_node.meta["val"])
return super().call_cond(pred, true_fn, false_fn, inputs, meta)

def call_while(
    self,
    cond_fn: torch.fx.GraphModule,
    body_fn: torch.fx.GraphModule,
    carried_inputs: List[ProxyValue],
    additional_inputs: List[ProxyValue],
    meta: NodeMetadata,
):
    """Propagate tensor specs through a ``while_loop`` node.

    A ``while_loop`` yields values with the same structure/shapes as its
    carried inputs, so the output spec is derived directly from
    ``carried_inputs`` before delegating to the parent implementation.
    """
    # Map every carried-input leaf to a spec (TensorSpec for tensors,
    # scalars pass through) and attach it to the node metadata.
    out_spec = pytree.tree_map(make_spec, carried_inputs)
    meta["spec"] = out_spec
    # Defer actual node construction to the base ExportPass behavior.
    return super().call_while(
        cond_fn, body_fn, carried_inputs, additional_inputs, meta
    )

def call_map(
self,
f: torch.fx.GraphModule,
Expand Down
9 changes: 7 additions & 2 deletions exir/program/test/test_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
NonDecompTestPartitioner,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.error import ExportError
from executorch.exir.error import ExportError, InternalError
from executorch.exir.lowered_backend_module import get_lowered_submodules
from executorch.exir.pass_base import ExportPass
from executorch.exir.passes import MemoryPlanningPass
Expand Down Expand Up @@ -313,7 +313,12 @@ def body_fn(it, x):
# Instantiate and export
inp = (torch.tensor(3), torch.randn(2, 2))
exported = export(M(), inp)
to_edge(exported)
ep = to_edge(exported)
# TODO(jakeszwe)
with self.assertRaisesRegex(
InternalError, "Unsupported control flow operator: while_loop"
):
ep.to_executorch()

def test_constraint_present_after_dce(self):
import executorch.exir as exir
Expand Down
Loading