Skip to content

Commit 023d392

Browse files
committed
Move cpu tensors over to numpy
1 parent a570dbf commit 023d392

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

pytensor/link/pytorch/linker.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,12 @@ def input_filter(self, inp: Any) -> Any:
1313
return pytorch_typify(inp)
1414

1515
def output_filter(self, var: Variable, out: Any) -> Any:
16-
return out.cpu()
16+
from torch import is_tensor
17+
18+
if is_tensor(out) and out.device.type == "cpu":
19+
return out.detach().numpy()
20+
else:
21+
return out
1722

1823
def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
1924
from pytensor.link.pytorch.dispatch import pytorch_funcify

0 commit comments

Comments
 (0)