Skip to content

Commit 0d24cdc

Browse files
author
Ian Schweer
committed
Add (and fix) test
1 parent 07e6113 commit 0d24cdc

File tree

2 files changed

+32
-1
lines changed

2 files changed

+32
-1
lines changed

pytensor/link/pytorch/linker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def __call__(self, *args, **kwargs):
7575

7676
# unset attrs
7777
for n, _ in self.gen_functors:
78-
if getattr(pytensor.link.utils, n, False):
78+
if getattr(pytensor.link.utils, n[1:], False):
7979
delattr(pytensor.link.utils, n[1:])
8080

8181
return res

tests/link/pytorch/test_basic.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323

2424
torch = pytest.importorskip("torch")
25+
torch_dispatch = pytest.importorskip("pytensor.link.pytorch.dispatch.basic")
2526

2627

2728
optimizer = RewriteDatabaseQuery(
@@ -343,3 +344,33 @@ def test_pytorch_OpFromGraph():
343344

344345
f = FunctionGraph([x, y, z], [out])
345346
compare_pytorch_and_py(f, [xv, yv, zv])
347+
348+
349+
def test_pytorch_link_references():
350+
import pytensor.link.utils as m
351+
352+
class BasicOp(Op):
353+
def __init__(self):
354+
super().__init__()
355+
356+
def make_node(self, *x):
357+
return Apply(self, list(x), [xi.type() for xi in x])
358+
359+
def perform(self, *_):
360+
raise RuntimeError("In perform")
361+
362+
@torch_dispatch.pytorch_funcify.register(BasicOp)
363+
def fn(op, node, **kwargs):
364+
def inner_fn(x):
365+
assert "inner_fn" in dir(m), "not available during dispatch"
366+
return x
367+
368+
return inner_fn
369+
370+
x = vector("x")
371+
op = BasicOp()
372+
out = op(x)
373+
374+
f = function([x], out, mode="PYTORCH")
375+
f(torch.ones(3))
376+
assert "inner_fn" not in dir(m), "function call reference leaked"

0 commit comments

Comments
 (0)