Skip to content

Commit 623dfbe

Browse files
author
Ian Schweer
committed
Only call .cpu when necessary
1 parent 714759c commit 623dfbe

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

pytensor/link/pytorch/linker.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import copy
22
from typing import Any
33

4+
from torch import is_tensor
5+
46
from pytensor.graph.basic import Variable
57
from pytensor.link.basic import JITLinker
68
from pytensor.link.utils import unique_name_generator
@@ -19,7 +21,10 @@ def input_filter(self, inp: Any) -> Any:
1921
return pytorch_typify(inp)
2022

2123
def output_filter(self, var: Variable, out: Any) -> Any:
22-
return out.cpu()
24+
if is_tensor(out):
25+
return out.cpu()
26+
else:
27+
return out
2328

2429
def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
2530
from pytensor.link.pytorch.dispatch import pytorch_funcify

0 commit comments

Comments
 (0)