Skip to content

Commit 0fa73fd

Browse files
authored
Support while in pass base
Differential Revision: D77382557 Pull Request resolved: #12004
1 parent 55c8b8d commit 0fa73fd

File tree

2 files changed

+51
-0
lines changed

2 files changed

+51
-0
lines changed

exir/pass_base.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,11 @@ def call_function(
340340
elif target == torch.ops.higher_order.cond:
341341
pred, true_fn, false_fn, inputs = args
342342
return self.callback.call_cond(pred, true_fn, false_fn, inputs, meta)
343+
elif target == torch.ops.higher_order.while_loop:
344+
cond, body, carried_inputs, additional_inputs = args
345+
return self.callback.call_while(
346+
cond, body, carried_inputs, additional_inputs, meta
347+
)
343348
elif target == torch.ops.higher_order.map_impl:
344349
f, mapped_args, operands = args # type: ignore[assignment]
345350
return self.callback.call_map(f, mapped_args, operands, meta)
@@ -497,6 +502,31 @@ def call_cond(
497502
meta,
498503
)
499504

505+
def call_while(
506+
self,
507+
cond_fn: torch.fx.GraphModule,
508+
body_fn: torch.fx.GraphModule,
509+
carried_inputs: List[Argument],
510+
additional_inputs: List[Argument],
511+
meta: NodeMetadata,
512+
) -> ProxyValue:
513+
cond_fn = self.call_submodule(cond_fn, (*carried_inputs, *additional_inputs))
514+
body_fn = self.call_submodule(body_fn, (*carried_inputs, *additional_inputs))
515+
assert cond_fn is not None
516+
assert body_fn is not None
517+
return self._fx(
518+
"call_function",
519+
torch.ops.higher_order.while_loop,
520+
(
521+
cond_fn.graph_module,
522+
body_fn.graph_module,
523+
carried_inputs,
524+
additional_inputs,
525+
),
526+
{},
527+
meta,
528+
)
529+
500530
def call_map(
501531
self,
502532
f: torch.fx.GraphModule,

exir/program/test/test_program.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,27 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
294294
for node in ep.graph.nodes:
295295
self.assertNotEqual(node.op, "get_attr")
296296

297+
def test_while(self):
298+
class M(torch.nn.Module):
299+
def __init__(self) -> None:
300+
super().__init__()
301+
self.linear = torch.nn.Linear(2, 2)
302+
self.dec = torch.nn.Buffer(torch.tensor(1))
303+
304+
def forward(self, iter, x):
305+
def cond_fn(it, x):
306+
return it - self.dec > 0
307+
308+
def body_fn(it, x):
309+
return it - 1, self.linear(x)
310+
311+
return torch._higher_order_ops.while_loop(cond_fn, body_fn, (iter, x))
312+
313+
# Instantiate and export
314+
inp = (torch.tensor(3), torch.randn(2, 2))
315+
exported = export(M(), inp)
316+
to_edge(exported)
317+
297318
def test_constraint_present_after_dce(self):
298319
import executorch.exir as exir
299320

0 commit comments

Comments
 (0)