From 6a39bcb813df626ab034619d20175ceb52a124b9 Mon Sep 17 00:00:00 2001 From: Ch0ronomato Date: Sun, 15 Dec 2024 09:51:10 -0800 Subject: [PATCH 1/2] Clear warning message from torch --- pytensor/link/pytorch/dispatch/basic.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py index 11e1d6c63a..2dcf2dfc36 100644 --- a/pytensor/link/pytorch/dispatch/basic.py +++ b/pytensor/link/pytorch/dispatch/basic.py @@ -123,7 +123,10 @@ def arange(start, stop, step): def pytorch_funcify_Join(op, **kwargs): def join(axis, *tensors): # tensors could also be tuples, and in this case they don't have a ndim - tensors = [torch.tensor(tensor) for tensor in tensors] + tensors = [ + torch.tensor(tensor) if not torch.is_tensor(tensor) else tensor + for tensor in tensors + ] return torch.cat(tensors, dim=axis) From 8a3bf6ab3c029c158a4a29c2f2a485ef7618736a Mon Sep 17 00:00:00 2001 From: Ch0ronomato Date: Sun, 15 Dec 2024 09:51:58 -0800 Subject: [PATCH 2/2] Remove gradient tracking --- pytensor/link/pytorch/linker.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/pytensor/link/pytorch/linker.py b/pytensor/link/pytorch/linker.py index d47aa43dda..48675c5a4d 100644 --- a/pytensor/link/pytorch/linker.py +++ b/pytensor/link/pytorch/linker.py @@ -51,7 +51,8 @@ class wrapper: """ def __init__(self, fn, gen_functors): - self.fn = torch.compile(fn) + with torch.no_grad(): + self.fn = torch.compile(fn) self.gen_functors = gen_functors.copy() def __call__(self, *inputs, **kwargs): @@ -62,7 +63,9 @@ def __call__(self, *inputs, **kwargs): setattr(pytensor.link.utils, n[1:], fn) # Torch does not accept numpy inputs and may return GPU objects - outs = self.fn(*(pytorch_typify(inp) for inp in inputs), **kwargs) + with torch.no_grad(): + ins = (pytorch_typify(inp) for inp in inputs) + outs = self.fn(*ins, **kwargs) # unset attrs for n, _ in self.gen_functors: