From e45dacb5c378547fa981ec8cec43201e6cc89bb2 Mon Sep 17 00:00:00 2001 From: Yidi Wu Date: Mon, 21 Jul 2025 21:44:39 -0700 Subject: [PATCH] migrate exir test to use torch.cond instead of the hop directly (#12695) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/12695 torch.cond is the use-facing api, which allows closure, have safety checks etc. Differential Revision: D78696006 --- exir/backend/test/test_backends.py | 2 +- exir/tests/control_flow_models.py | 12 +++--------- exir/tests/test_passes.py | 4 +--- 3 files changed, 5 insertions(+), 13 deletions(-) diff --git a/exir/backend/test/test_backends.py b/exir/backend/test/test_backends.py index 7576fed8eb2..544b97bb53c 100644 --- a/exir/backend/test/test_backends.py +++ b/exir/backend/test/test_backends.py @@ -1033,7 +1033,7 @@ def false_fn(x, y): def f(x, y): x = x + y - x = torch.ops.higher_order.cond(x[0][0] == 1, true_fn, false_fn, [x, y]) + x = torch.cond(x[0][0] == 1, true_fn, false_fn, [x, y]) x = x - y return x diff --git a/exir/tests/control_flow_models.py b/exir/tests/control_flow_models.py index 5aab85cc45a..3c0fd8badab 100644 --- a/exir/tests/control_flow_models.py +++ b/exir/tests/control_flow_models.py @@ -20,9 +20,7 @@ def true_branch(x): def false_branch(x): return x * x - return torch.ops.higher_order.cond( - inp.sum() > 4, true_branch, false_branch, [inp] - ) + return torch.cond(inp.sum() > 4, true_branch, false_branch, [inp]) def get_random_inputs(self): return (torch.rand(5),) @@ -39,9 +37,7 @@ def true_branch(x): def false_branch(x): return x * x * x - return torch.ops.higher_order.cond( - inp.sum() > 4, true_branch, false_branch, [inp] - ) + return torch.cond(inp.sum() > 4, true_branch, false_branch, [inp]) def get_upper_bound_inputs(self): return (torch.rand(8),) @@ -72,9 +68,7 @@ def true_branch(x): def false_branch(x): return x * 2 - return torch.ops.higher_order.cond( - inp.sum() > 4, true_branch, false_branch, [inp] - ) + return torch.cond(inp.sum() > 4, true_branch, false_branch, [inp]) def get_random_inputs(self): return (torch.eye(5) * 2,) diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py index 422b133f0e0..70a4c88e3b6 100644 --- a/exir/tests/test_passes.py +++ b/exir/tests/test_passes.py @@ -1463,9 +1463,7 @@ def forward(self, pred, x): out = torch.nn.functional.linear( x, self.w.to(torch.float16).to(torch.float32) ) - return torch.ops.higher_order.cond( - pred, self.true_fn, self.false_fn, [out] - ) + return torch.cond(pred, self.true_fn, self.false_fn, [out]) mod = Module() x = torch.randn([3, 3])