Skip to content

Commit a601a27

Browse files
author
Ian Schweer
committed
Update away from torch.where
1 parent bfb97ea commit a601a27

File tree

2 files changed

+11
-13
lines changed

2 files changed

+11
-13
lines changed

pytensor/link/pytorch/dispatch/basic.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,9 +138,11 @@ def makevector(*x):
138138
@pytorch_funcify.register(IfElse)
139139
def pytorch_funcify_IfElse(op, **kwargs):
140140
n_outs = op.n_outs
141-
assert n_outs == 1
142141

143-
def ifelse(cond, *args, n_outs=n_outs):
144-
return torch.where(cond, *args)
142+
def ifelse(cond, ifpath, elsepath, n_outs=n_outs):
143+
if cond:
144+
return ifpath
145+
else:
146+
return elsepath
145147

146148
return ifelse

tests/link/pytorch/test_basic.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -308,14 +308,10 @@ def test_pytorch_ifelse():
308308
true_vals = np.r_[1, 2, 3]
309309
false_vals = np.r_[-1, -2, -3]
310310

311-
x = ifelse(np.array(True), true_vals, false_vals)
312-
x_fg = FunctionGraph([], [x])
313-
314-
compare_pytorch_and_py(x_fg, [])
315-
316-
a = scalar("a")
317-
a.tag.test_value = np.array(0.2, dtype=config.floatX)
318-
x = ifelse(a < 0.5, true_vals, false_vals)
319-
x_fg = FunctionGraph([a], [x]) # I.e. False
311+
for test_value, cond in [(0.2, 0.5), (0.5, 0.4)]:
312+
a = scalar("a")
313+
a.tag.test_value = np.array(test_value, dtype=config.floatX)
314+
x = ifelse(a < cond, true_vals, false_vals)
315+
x_fg = FunctionGraph([a], [x]) # I.e. False
320316

321-
compare_pytorch_and_py(x_fg, [get_test_value(i) for i in x_fg.inputs])
317+
compare_pytorch_and_py(x_fg, [get_test_value(i) for i in x_fg.inputs])

0 commit comments

Comments
 (0)