Skip to content

Commit 052fdc2

Browse files
committed
restore pytorch
1 parent edacc0e commit 052fdc2

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

pytensor/link/pytorch/linker.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,11 @@ def conversion_func_register(*args, **kwargs):
3838
)
3939

4040
def jit_compile(self, fn):
41-
import mlx.core as mx
41+
import torch
4242

43-
from pytensor.link.mlx.dispatch import mlx_typify
43+
torch._dynamo.config.capture_dynamic_output_shape_ops = True
44+
45+
from pytensor.link.pytorch.dispatch import pytorch_typify
4446

4547
class wrapper:
4648
"""
@@ -54,7 +56,7 @@ class wrapper:
5456
"""
5557

5658
def __init__(self, fn, gen_functors):
57-
self.fn = mx.compile(fn)
59+
self.fn = torch.compile(fn)
5860
self.gen_functors = gen_functors.copy()
5961

6062
def __call__(self, *inputs, **kwargs):
@@ -65,7 +67,7 @@ def __call__(self, *inputs, **kwargs):
6567
setattr(pytensor.link.utils, n[1:], fn)
6668

6769
# Torch does not accept numpy inputs and may return GPU objects
68-
outs = self.fn(*(mlx_typify(inp) for inp in inputs), **kwargs)
70+
outs = self.fn(*(pytorch_typify(inp) for inp in inputs), **kwargs)
6971

7072
# unset attrs
7173
for n, _ in self.gen_functors:

0 commit comments

Comments
 (0)