Skip to content

Commit d14c258

Browse files
author
Ian Schweer
committed
Make lifetime of closure explicit
1 parent 33a4d48 commit d14c258

File tree

4 files changed

+24
-7
lines changed

4 files changed

+24
-7
lines changed

pytensor/link/pytorch/dispatch/basic.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import inspect
2+
import sys
13
from functools import singledispatch
24
from types import NoneType
35

@@ -56,9 +58,28 @@ def pytorch_funcify_FunctionGraph(
5658
fgraph_name="pytorch_funcified_fgraph",
5759
**kwargs,
5860
):
61+
def conversion_func_register(*args, **kwargs):
62+
last_frame = inspect.currentframe().f_back
63+
functor = pytorch_funcify(*args, **kwargs)
64+
if last_frame:
65+
module = sys.modules[last_frame.f_globals["__name__"]]
66+
name = functor.__name__
67+
i = 0
68+
prefix = ""
69+
while True:
70+
name = functor.__name__ + prefix
71+
if name in dir(module):
72+
i += 1
73+
prefix = f"_{i}"
74+
else:
75+
print("Registering functor", name)
76+
setattr(module, name, functor)
77+
break
78+
return functor
79+
5980
return fgraph_to_python(
6081
fgraph,
61-
pytorch_funcify,
82+
conversion_func_register,
6283
type_conversion_fn=pytorch_typify,
6384
fgraph_name=fgraph_name,
6485
**kwargs,
@@ -175,9 +196,7 @@ def pytorch_funcify_OpFromGraph(op, node, **kwargs):
175196
PYTORCH.optimizer(op.fgraph)
176197

177198
fgraph_fn = pytorch_funcify(op.fgraph, **kwargs, squeeze_output=True)
178-
# Disable one step inlining to prevent torch from trying to import local functions
179-
# defined in `pytorch_funcify`
180-
return torch.compiler.disable(fgraph_fn, recursive=False)
199+
return fgraph_fn
181200

182201

183202
@pytorch_funcify.register(TensorFromScalar)

pytensor/link/pytorch/dispatch/blockwise.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ def funcify_Blockwise(op: Blockwise, node, *args, **kwargs):
1616
for _ in range(batched_dims):
1717
inner_func = torch.vmap(inner_func)
1818

19-
@torch.compiler.disable(recursive=False)
2019
def batcher(*inputs):
2120
op._check_runtime_broadcast(node, inputs)
2221
# broadcast on batched_dims

tests/link/pytorch/test_basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,7 @@ def test_pytorch_OpFromGraph():
335335
ofg_2 = OpFromGraph([x, y], [x * y, x - y])
336336

337337
o1, o2 = ofg_2(y, z)
338-
out = ofg_1(x, o1) + o2
338+
out = ofg_1(x, o1) / o2
339339

340340
xv = np.ones((2, 2), dtype=config.floatX)
341341
yv = np.ones((2, 2), dtype=config.floatX) * 3

tests/link/pytorch/test_blockwise.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ def perform(self, *_):
2929

3030
@basic.pytorch_funcify.register(TestOp)
3131
def evaluate_test_op(op, **_):
32-
@torch.compiler.disable(recursive=False)
3332
def func(a, b):
3433
op.call_shapes.extend(map(torch.Tensor.size, [a, b]))
3534
return a @ b

0 commit comments

Comments
 (0)