Skip to content

Commit 881bca1

Browse files
author
Ian Schweer
committed
Add torch ifelse
1 parent a3f0a4e commit 881bca1

File tree

2 files changed

+29
-0
lines changed

2 files changed

+29
-0
lines changed

pytensor/link/pytorch/dispatch/basic.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from pytensor.compile.ops import DeepCopyOp
77
from pytensor.graph.fg import FunctionGraph
8+
from pytensor.ifelse import IfElse
89
from pytensor.link.utils import fgraph_to_python
910
from pytensor.raise_op import CheckAndRaise
1011
from pytensor.tensor.basic import Alloc, AllocEmpty, ARange, Eye, Join, MakeVector
@@ -132,3 +133,16 @@ def makevector(*x):
132133
return torch.tensor(x, dtype=torch_dtype)
133134

134135
return makevector
136+
137+
138+
@pytorch_funcify.register(IfElse)
139+
def pytorch_funcify_IfElse(op, **kwargs):
140+
n_outs = op.n_outs
141+
142+
def ifelse(cond, *true_and_false, n_outs=n_outs):
143+
if cond:
144+
return torch.stack(true_and_false[:n_outs])
145+
else:
146+
return torch.stack(true_and_false[n_outs:])
147+
148+
return ifelse

tests/link/pytorch/test_basic.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from pytensor.graph.basic import Apply
1313
from pytensor.graph.fg import FunctionGraph
1414
from pytensor.graph.op import Op
15+
from pytensor.ifelse import ifelse
1516
from pytensor.raise_op import CheckAndRaise
1617
from pytensor.tensor import alloc, arange, as_tensor, empty, eye
1718
from pytensor.tensor.type import matrix, scalar, vector
@@ -301,3 +302,17 @@ def test_pytorch_MakeVector():
301302
x_fg = FunctionGraph([], [x])
302303

303304
compare_pytorch_and_py(x_fg, [])
305+
306+
307+
def test_pytorch_ifelse():
308+
p1_vals = np.r_[1, 2, 3]
309+
p2_vals = np.r_[-1, -2, -3]
310+
311+
for test_value, cond in [(0.2, 0.5), (0.5, 0.4)]:
312+
a = scalar("a")
313+
x = ifelse(
314+
a < cond, tuple(np.r_[p1_vals, p2_vals]), tuple(np.r_[p2_vals, p1_vals])
315+
)
316+
x_fg = FunctionGraph([a], x)
317+
318+
compare_pytorch_and_py(x_fg, np.array([test_value], dtype=config.floatX))

0 commit comments

Comments
 (0)