Skip to content

Commit 494bcc5

Browse files
author
Ian Schweer
committed
Smarter linker
1 parent 8b8e174 commit 494bcc5

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

pytensor/link/pytorch/linker.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,11 @@ def input_filter(self, inp: Any) -> Any:
1313
return pytorch_typify(inp)
1414

1515
def output_filter(self, var: Variable, out: Any) -> Any:
16-
return out.cpu()
16+
import torch
17+
18+
if torch.is_tensor(out):
19+
return out.cpu()
20+
return out
1721

1822
def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
1923
from pytensor.link.pytorch.dispatch import pytorch_funcify

0 commit comments

Comments
 (0)