We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent f0507d5 commit c894b08Copy full SHA for c894b08
pytensor/link/pytorch/linker.py
@@ -1,7 +1,5 @@
1
from typing import Any
2
3
-from torch import is_tensor
4
-
5
from pytensor.graph.basic import Variable
6
from pytensor.link.basic import JITLinker
7
@@ -15,6 +13,8 @@ def input_filter(self, inp: Any) -> Any:
15
13
return pytorch_typify(inp)
16
14
17
def output_filter(self, var: Variable, out: Any) -> Any:
+ from torch import is_tensor
+
18
if is_tensor(out):
19
return out.cpu()
20
else:
0 commit comments