diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py index e2edcf0fe4..e0aa80e18b 100644 --- a/pytensor/link/pytorch/dispatch/basic.py +++ b/pytensor/link/pytorch/dispatch/basic.py @@ -9,6 +9,7 @@ from pytensor.compile.builders import OpFromGraph from pytensor.compile.ops import DeepCopyOp from pytensor.graph.fg import FunctionGraph +from pytensor.ifelse import IfElse from pytensor.link.utils import fgraph_to_python from pytensor.raise_op import CheckAndRaise from pytensor.tensor.basic import ( @@ -153,6 +154,19 @@ def makevector(*x): return makevector +@pytorch_funcify.register(IfElse) +def pytorch_funcify_IfElse(op, **kwargs): + n_outs = op.n_outs + + def ifelse(cond, *true_and_false, n_outs=n_outs): + if cond: + return true_and_false[:n_outs] + else: + return true_and_false[n_outs:] + + return ifelse + + @pytorch_funcify.register(OpFromGraph) def pytorch_funcify_OpFromGraph(op, node, **kwargs): kwargs.pop("storage_map", None) diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py index 1be74faf17..bb1958f43e 100644 --- a/tests/link/pytorch/test_basic.py +++ b/tests/link/pytorch/test_basic.py @@ -13,6 +13,7 @@ from pytensor.graph.basic import Apply from pytensor.graph.fg import FunctionGraph from pytensor.graph.op import Op +from pytensor.ifelse import ifelse from pytensor.raise_op import CheckAndRaise from pytensor.tensor import alloc, arange, as_tensor, empty, eye from pytensor.tensor.type import matrices, matrix, scalar, vector @@ -304,6 +305,23 @@ def test_pytorch_MakeVector(): compare_pytorch_and_py(x_fg, []) +def test_pytorch_ifelse(): + p1_vals = np.r_[1, 2, 3] + p2_vals = np.r_[-1, -2, -3] + + a = scalar("a") + x = ifelse(a < 0.5, tuple(np.r_[p1_vals, p2_vals]), tuple(np.r_[p2_vals, p1_vals])) + x_fg = FunctionGraph([a], x) + + compare_pytorch_and_py(x_fg, np.array([0.2], dtype=config.floatX)) + + a = scalar("a") + x = ifelse(a < 0.4, tuple(np.r_[p1_vals, p2_vals]), tuple(np.r_[p2_vals, p1_vals])) + x_fg = FunctionGraph([a], x) + + compare_pytorch_and_py(x_fg, np.array([0.5], dtype=config.floatX)) + + def test_pytorch_OpFromGraph(): x, y, z = matrices("xyz") ofg_1 = OpFromGraph([x, y], [x + y])