Skip to content

Commit a2afd8f

Browse files
ydwu4facebook-github-bot
authored andcommitted
migrate exir test to use torch.cond instead of the hop directly
Summary: torch.cond is the use-facing api, which allows closure, have safety checks etc. Differential Revision: D78696006
1 parent 76ed145 commit a2afd8f

File tree

3 files changed

+5
-5
lines changed

3 files changed

+5
-5
lines changed

exir/backend/test/test_backends.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1033,7 +1033,7 @@ def false_fn(x, y):
10331033

10341034
def f(x, y):
10351035
x = x + y
1036-
x = torch.ops.higher_order.cond(x[0][0] == 1, true_fn, false_fn, [x, y])
1036+
x = torch.cond(x[0][0] == 1, true_fn, false_fn, [x, y])
10371037
x = x - y
10381038
return x
10391039

exir/tests/control_flow_models.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def true_branch(x):
2020
def false_branch(x):
2121
return x * x
2222

23-
return torch.ops.higher_order.cond(
23+
return torch.cond(
2424
inp.sum() > 4, true_branch, false_branch, [inp]
2525
)
2626

@@ -39,7 +39,7 @@ def true_branch(x):
3939
def false_branch(x):
4040
return x * x * x
4141

42-
return torch.ops.higher_order.cond(
42+
return torch.cond(
4343
inp.sum() > 4, true_branch, false_branch, [inp]
4444
)
4545

@@ -72,7 +72,7 @@ def true_branch(x):
7272
def false_branch(x):
7373
return x * 2
7474

75-
return torch.ops.higher_order.cond(
75+
return torch.cond(
7676
inp.sum() > 4, true_branch, false_branch, [inp]
7777
)
7878

exir/tests/test_passes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1463,7 +1463,7 @@ def forward(self, pred, x):
14631463
out = torch.nn.functional.linear(
14641464
x, self.w.to(torch.float16).to(torch.float32)
14651465
)
1466-
return torch.ops.higher_order.cond(
1466+
return torch.cond(
14671467
pred, self.true_fn, self.false_fn, [out]
14681468
)
14691469

0 commit comments

Comments
 (0)