diff --git a/exir/backend/test/test_backends.py b/exir/backend/test/test_backends.py index 7576fed8eb2..544b97bb53c 100644 --- a/exir/backend/test/test_backends.py +++ b/exir/backend/test/test_backends.py @@ -1033,7 +1033,7 @@ def false_fn(x, y): def f(x, y): x = x + y - x = torch.ops.higher_order.cond(x[0][0] == 1, true_fn, false_fn, [x, y]) + x = torch.cond(x[0][0] == 1, true_fn, false_fn, [x, y]) x = x - y return x diff --git a/exir/tests/control_flow_models.py b/exir/tests/control_flow_models.py index 5aab85cc45a..3c0fd8badab 100644 --- a/exir/tests/control_flow_models.py +++ b/exir/tests/control_flow_models.py @@ -20,9 +20,7 @@ def true_branch(x): def false_branch(x): return x * x - return torch.ops.higher_order.cond( - inp.sum() > 4, true_branch, false_branch, [inp] - ) + return torch.cond(inp.sum() > 4, true_branch, false_branch, [inp]) def get_random_inputs(self): return (torch.rand(5),) @@ -39,9 +37,7 @@ def true_branch(x): def false_branch(x): return x * x * x - return torch.ops.higher_order.cond( - inp.sum() > 4, true_branch, false_branch, [inp] - ) + return torch.cond(inp.sum() > 4, true_branch, false_branch, [inp]) def get_upper_bound_inputs(self): return (torch.rand(8),) @@ -72,9 +68,7 @@ def true_branch(x): def false_branch(x): return x * 2 - return torch.ops.higher_order.cond( - inp.sum() > 4, true_branch, false_branch, [inp] - ) + return torch.cond(inp.sum() > 4, true_branch, false_branch, [inp]) def get_random_inputs(self): return (torch.eye(5) * 2,) diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py index 422b133f0e0..70a4c88e3b6 100644 --- a/exir/tests/test_passes.py +++ b/exir/tests/test_passes.py @@ -1463,9 +1463,7 @@ def forward(self, pred, x): out = torch.nn.functional.linear( x, self.w.to(torch.float16).to(torch.float32) ) - return torch.ops.higher_order.cond( - pred, self.true_fn, self.false_fn, [out] - ) + return torch.cond(pred, self.true_fn, self.false_fn, [out]) mod = Module() x = torch.randn([3, 3])