Skip to content

Commit c4b20ec

Browse files
author
Ian Schweer
committed
Basic support for makeop
1 parent 426931b commit c4b20ec

File tree

3 files changed

+41
-1
lines changed

3 files changed

+41
-1
lines changed

pytensor/compile/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
OPT_O3,
3131
OPT_STABILIZE,
3232
OPT_UNSAFE,
33+
PYTORCH,
3334
AddDestroyHandler,
3435
AddFeatureOptimizer,
3536
Mode,

pytensor/link/pytorch/dispatch/basic.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
import torch
55

6+
from pytensor.compile import PYTORCH
7+
from pytensor.compile.builders import OpFromGraph
68
from pytensor.compile.ops import DeepCopyOp
79
from pytensor.graph.fg import FunctionGraph
810
from pytensor.link.utils import fgraph_to_python
@@ -132,3 +134,17 @@ def makevector(*x):
132134
return torch.tensor(x, dtype=torch_dtype)
133135

134136
return makevector
137+
138+
139+
@pytorch_funcify.register(OpFromGraph)
140+
def pytorch_funcify_OpFromGraph(op, node=None, **kwargs):
141+
_ = kwargs.pop("storage_map", None)
142+
143+
PYTORCH.optimizer(op.fgraph)
144+
fgraph_fn = torch.compile(pytorch_funcify(op.fgraph, **kwargs))
145+
146+
def opfromgraph(*inputs, dim=op.fgraph.outputs):
147+
res = fgraph_fn(*inputs)
148+
return res[0]
149+
150+
return opfromgraph

tests/link/pytorch/test_basic.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import pytest
66

77
import pytensor.tensor.basic as ptb
8+
from pytensor.compile.builders import OpFromGraph
89
from pytensor.compile.function import function
910
from pytensor.compile.mode import get_mode
1011
from pytensor.compile.sharedvalue import SharedVariable, shared
@@ -14,7 +15,7 @@
1415
from pytensor.graph.op import Op
1516
from pytensor.raise_op import CheckAndRaise
1617
from pytensor.tensor import alloc, arange, as_tensor, empty, eye
17-
from pytensor.tensor.type import matrix, scalar, vector
18+
from pytensor.tensor.type import matrices, matrix, scalar, vector
1819

1920

2021
torch = pytest.importorskip("torch")
@@ -301,3 +302,25 @@ def test_pytorch_MakeVector():
301302
x_fg = FunctionGraph([], [x])
302303

303304
compare_pytorch_and_py(x_fg, [])
305+
306+
307+
def test_pytorch_OpFromGraph():
308+
x, y, z = matrices("xyz")
309+
ofg_1 = OpFromGraph([x, y], [x + y])
310+
OpFromGraph([x, y], [x * y, x - y])
311+
312+
# o1, o2 = ofg_2(y, z)
313+
# out = ofg_1(x, o1) + o2
314+
315+
out = ofg_1(y, z)
316+
317+
xv = np.ones((2, 2), dtype=config.floatX)
318+
np.ones((2, 2), dtype=config.floatX) * 3
319+
zv = np.ones((2, 2), dtype=config.floatX) * 5
320+
321+
f = FunctionGraph([y, z], [out])
322+
import pytensor.printing
323+
324+
pytensor.printing.debugprint(f)
325+
326+
compare_pytorch_and_py(f, [xv, zv])

0 commit comments

Comments
 (0)