Skip to content

Commit 48450b0

Browse files
author
Ian Schweer
committed
Fix test to allow for n_outs>1
1 parent a601a27 commit 48450b0

File tree

2 files changed

+9
-7
lines changed

2 files changed

+9
-7
lines changed

pytensor/link/pytorch/dispatch/basic.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -139,10 +139,10 @@ def makevector(*x):
139139
def pytorch_funcify_IfElse(op, **kwargs):
140140
n_outs = op.n_outs
141141

142-
def ifelse(cond, ifpath, elsepath, n_outs=n_outs):
142+
def ifelse(cond, *true_and_false, n_outs=n_outs):
143143
if cond:
144-
return ifpath
144+
return torch.stack(true_and_false[:n_outs])
145145
else:
146-
return elsepath
146+
return torch.stack(true_and_false[n_outs:])
147147

148148
return ifelse

tests/link/pytorch/test_basic.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -305,13 +305,15 @@ def test_pytorch_MakeVector():
305305

306306

307307
def test_pytorch_ifelse():
308-
true_vals = np.r_[1, 2, 3]
309-
false_vals = np.r_[-1, -2, -3]
308+
p1_vals = np.r_[1, 2, 3]
309+
p2_vals = np.r_[-1, -2, -3]
310310

311311
for test_value, cond in [(0.2, 0.5), (0.5, 0.4)]:
312312
a = scalar("a")
313313
a.tag.test_value = np.array(test_value, dtype=config.floatX)
314-
x = ifelse(a < cond, true_vals, false_vals)
315-
x_fg = FunctionGraph([a], [x]) # I.e. False
314+
x = ifelse(
315+
a < cond, tuple(np.r_[p1_vals, p2_vals]), tuple(np.r_[p2_vals, p1_vals])
316+
)
317+
x_fg = FunctionGraph([a], x)
316318

317319
compare_pytorch_and_py(x_fg, [get_test_value(i) for i in x_fg.inputs])

0 commit comments

Comments
 (0)