Skip to content

Commit d03bfc7

Browse files
author
Ian Schweer
committed
Remove compiler disable
1 parent 9b4059d commit d03bfc7

File tree

3 files changed

+1
-6
lines changed

3 files changed

+1
-6
lines changed

pytensor/link/pytorch/dispatch/basic.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -181,9 +181,7 @@ def pytorch_funcify_OpFromGraph(op, node, **kwargs):
181181
# Apply inner rewrites
182182
PYTORCH.optimizer(op.fgraph)
183183
fgraph_fn = pytorch_funcify(op.fgraph, **kwargs, squeeze_output=True)
184-
# Disable one step inlining to prevent torch from trying to import local functions
185-
# defined in `pytorch_funcify`
186-
return torch.compiler.disable(fgraph_fn, recursive=False)
184+
return fgraph_fn
187185

188186

189187
@pytorch_funcify.register(TensorFromScalar)

pytensor/link/pytorch/dispatch/blockwise.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import torch
2-
import torch.compiler
32

43
from pytensor.graph import FunctionGraph
54
from pytensor.link.pytorch.dispatch import pytorch_funcify
@@ -16,7 +15,6 @@ def funcify_Blockwise(op: Blockwise, node, *args, **kwargs):
1615
for _ in range(batched_dims):
1716
inner_func = torch.vmap(inner_func)
1817

19-
@torch.compiler.disable(recursive=False)
2018
def batcher(*inputs):
2119
op._check_runtime_broadcast(node, inputs)
2220
# broadcast on batched_dims

tests/link/pytorch/test_blockwise.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ def perform(self, *_):
2929

3030
@basic.pytorch_funcify.register(TestOp)
3131
def evaluate_test_op(op, **_):
32-
@torch.compiler.disable(recursive=False)
3332
def func(a, b):
3433
op.call_shapes.extend(map(torch.Tensor.size, [a, b]))
3534
return a @ b

0 commit comments

Comments
 (0)