Merged
Changes from 2 commits
pytensor/compile/__init__.py (1 change: 1 addition & 0 deletions)

@@ -30,6 +30,7 @@
     OPT_O3,
     OPT_STABILIZE,
     OPT_UNSAFE,
+    PYTORCH,
     AddDestroyHandler,
     AddFeatureOptimizer,
     Mode,
pytensor/link/pytorch/dispatch/basic.py (14 changes: 14 additions & 0 deletions)

@@ -1,8 +1,10 @@
 from functools import singledispatch
+from operator import itemgetter
 from types import NoneType
 
 import torch
 
+from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.ops import DeepCopyOp
 from pytensor.graph.fg import FunctionGraph
 from pytensor.link.utils import fgraph_to_python
@@ -132,3 +134,15 @@ def makevector(*x):
         return torch.tensor(x, dtype=torch_dtype)
 
     return makevector
+
+
+@pytorch_funcify.register(OpFromGraph)
+def pytorch_funcify_OpFromGraph(op, node=None, **kwargs):
+    _ = kwargs.pop("storage_map", None)
+
+    fgraph_fn = torch.compile(pytorch_funcify(op.fgraph, **kwargs))
+    return (
+        fgraph_fn
+        if len(op.fgraph.outputs) > 1
+        else lambda *args: itemgetter(0)(fgraph_fn(*args))
+    )

Review thread on the line `fgraph_fn = torch.compile(pytorch_funcify(op.fgraph, **kwargs))`:

Member: Do you need to compile the inner function? Is that a thing in PyTorch?

Contributor Author: I was following what Numba does, where it JITs the inner function. We could remove the inner torch.compile and just return op.fgraph if that seems more reasonable; that will still lead to some C-linker issues, FWIW.

Contributor Author: I removed the inner function; you only need to do indexing if the number of return values is more than 1.

Member: Numba can only have inner compiled functions. I don't know whether that is a requirement in PyTorch, or whether it has any advantages. We don't do it for JAX.

Contributor Author: I do not see / know of any requirement to have an inner compiled function.
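
For context on the return expression: a funcified FunctionGraph returns a tuple of outputs, so a single-output OpFromGraph has to be unwrapped before callers receive a bare tensor. Below is a minimal sketch of that unwrapping pattern; the fgraph_fn here is a stand-in callable, not the real pytorch_funcify output:

from operator import itemgetter

import torch


def fgraph_fn(x, y):
    # Stand-in for a funcified FunctionGraph: it returns a tuple of outputs,
    # even when the graph has only one output.
    return (x + y,)


def single_output_fn(*args):
    # Mirrors the dispatch's lambda: unwrap the 1-tuple with itemgetter(0).
    return itemgetter(0)(fgraph_fn(*args))


a = torch.ones(2)
assert torch.equal(single_output_fn(a, a), torch.tensor([2.0, 2.0]))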
tests/link/pytorch/test_basic.py (22 changes: 21 additions & 1 deletion)

@@ -5,6 +5,7 @@
 import pytest
 
 import pytensor.tensor.basic as ptb
+from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.function import function
 from pytensor.compile.mode import get_mode
 from pytensor.compile.sharedvalue import SharedVariable, shared
@@ -14,7 +15,7 @@
 from pytensor.graph.op import Op
 from pytensor.raise_op import CheckAndRaise
 from pytensor.tensor import alloc, arange, as_tensor, empty, eye
-from pytensor.tensor.type import matrix, scalar, vector
+from pytensor.tensor.type import matrices, matrix, scalar, vector
 
 
 torch = pytest.importorskip("torch")
@@ -301,3 +302,22 @@ def test_pytorch_MakeVector():
     x_fg = FunctionGraph([], [x])
 
     compare_pytorch_and_py(x_fg, [])
+
+
+def test_pytorch_OpFromGraph():
+    x, y, z = matrices("xyz")
+    ofg_1 = OpFromGraph([x, y], [x + y])
+    ofg_2 = OpFromGraph([x, y], [x * y, x - y])  # multi-output OFG for the disabled path below
+
+    # o1, o2 = ofg_2(y, z)
+    # out = ofg_1(x, o1) + o2
+
+    out = ofg_1(y, z)
+
+    xv = np.ones((2, 2), dtype=config.floatX)
+    yv = np.ones((2, 2), dtype=config.floatX) * 3  # only used once the ofg_2 path is enabled
+    zv = np.ones((2, 2), dtype=config.floatX) * 5
+
+    f = FunctionGraph([y, z], [out])
+
+    compare_pytorch_and_py(f, [xv, zv])
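
Together with the PYTORCH export added in pytensor/compile/__init__.py, this makes OpFromGraph compilable end to end through the PyTorch backend. A rough usage sketch, assuming the backend is registered under the mode name "PYTORCH" as that export suggests:

import numpy as np

import pytensor
from pytensor.compile.builders import OpFromGraph
from pytensor.tensor.type import matrices

x, y = matrices("xy")
ofg = OpFromGraph([x, y], [x + y])

# Compile with the PyTorch backend; the new dispatch handles the inner graph.
f = pytensor.function([x, y], ofg(x, y), mode="PYTORCH")

a = np.ones((2, 2), dtype=pytensor.config.floatX)
print(f(a, a))  # expected: a 2x2 array of 2.0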