Skip to content

Commit 8a3bf6a

Browse files
committed
Remove gradient tracking
1 parent 6a39bcb commit 8a3bf6a

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

pytensor/link/pytorch/linker.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ class wrapper:
5151
"""
5252

5353
def __init__(self, fn, gen_functors):
54-
self.fn = torch.compile(fn)
54+
with torch.no_grad():
55+
self.fn = torch.compile(fn)
5556
self.gen_functors = gen_functors.copy()
5657

5758
def __call__(self, *inputs, **kwargs):
@@ -62,7 +63,9 @@ def __call__(self, *inputs, **kwargs):
6263
setattr(pytensor.link.utils, n[1:], fn)
6364

6465
# Torch does not accept numpy inputs and may return GPU objects
65-
outs = self.fn(*(pytorch_typify(inp) for inp in inputs), **kwargs)
66+
with torch.no_grad():
67+
ins = (pytorch_typify(inp) for inp in inputs)
68+
outs = self.fn(*ins, **kwargs)
6669

6770
# unset attrs
6871
for n, _ in self.gen_functors:

0 commit comments

Comments
 (0)