Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pytensor/compile/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
OPT_O3,
OPT_STABILIZE,
OPT_UNSAFE,
PYTORCH,
AddDestroyHandler,
AddFeatureOptimizer,
Mode,
Expand Down
16 changes: 16 additions & 0 deletions pytensor/link/pytorch/dispatch/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
from types import NoneType

import torch
import torch.compiler

from pytensor.compile import PYTORCH
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.ops import DeepCopyOp
from pytensor.graph.fg import FunctionGraph
from pytensor.link.utils import fgraph_to_python
Expand Down Expand Up @@ -132,3 +135,16 @@ def makevector(*x):
return torch.tensor(x, dtype=torch_dtype)

return makevector


@pytorch_funcify.register(OpFromGraph)
def pytorch_funcify_OpFromGraph(op, node, **kwargs):
kwargs.pop("storage_map", None)

# Apply inner rewrites
PYTORCH.optimizer(op.fgraph)

fgraph_fn = pytorch_funcify(op.fgraph, **kwargs, squeeze_output=True)
# Disable one step inlining to prevent torch from trying to import local functions
# defined in `pytorch_funcify`
return torch.compiler.disable(fgraph_fn, recursive=False)
23 changes: 20 additions & 3 deletions tests/link/pytorch/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest

import pytensor.tensor.basic as ptb
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function import function
from pytensor.compile.mode import get_mode
from pytensor.compile.sharedvalue import SharedVariable, shared
Expand All @@ -14,7 +15,7 @@
from pytensor.graph.op import Op
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor import alloc, arange, as_tensor, empty, eye
from pytensor.tensor.type import matrix, scalar, vector
from pytensor.tensor.type import matrices, matrix, scalar, vector


torch = pytest.importorskip("torch")
Expand Down Expand Up @@ -67,9 +68,9 @@ def compare_pytorch_and_py(

if len(fgraph.outputs) > 1:
for j, p in zip(pytorch_res, py_res):
assert_fn(j.cpu(), p)
assert_fn(j.detach().cpu().numpy(), p)
else:
assert_fn([pytorch_res[0].cpu()], py_res)
assert_fn(pytorch_res[0].detach().cpu().numpy(), py_res[0])

return pytensor_torch_fn, pytorch_res

Expand Down Expand Up @@ -301,3 +302,19 @@ def test_pytorch_MakeVector():
x_fg = FunctionGraph([], [x])

compare_pytorch_and_py(x_fg, [])


def test_pytorch_OpFromGraph():
x, y, z = matrices("xyz")
ofg_1 = OpFromGraph([x, y], [x + y])
ofg_2 = OpFromGraph([x, y], [x * y, x - y])

o1, o2 = ofg_2(y, z)
out = ofg_1(x, o1) + o2

xv = np.ones((2, 2), dtype=config.floatX)
yv = np.ones((2, 2), dtype=config.floatX) * 3
zv = np.ones((2, 2), dtype=config.floatX) * 5

f = FunctionGraph([x, y, z], [out])
compare_pytorch_and_py(f, [xv, yv, zv])
Loading