Skip to content

Commit ad04464

Browse files
author
Ian Schweer
committed
Only call .cpu when necessary
1 parent ebaf641 commit ad04464

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

pytensor/link/pytorch/linker.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from typing import Any
22

3+
from torch import is_tensor
4+
35
from pytensor.graph.basic import Variable
46
from pytensor.link.basic import JITLinker
57

@@ -13,7 +15,10 @@ def input_filter(self, inp: Any) -> Any:
1315
return pytorch_typify(inp)
1416

1517
def output_filter(self, var: Variable, out: Any) -> Any:
16-
return out.cpu()
18+
if is_tensor(out):
19+
return out.cpu()
20+
else:
21+
return out
1722

1823
def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
1924
from pytensor.link.pytorch.dispatch import pytorch_funcify

0 commit comments

Comments
 (0)