We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent abfa9c5 commit 5fbd0ddCopy full SHA for 5fbd0dd
graph_net/torch/utils.py
@@ -17,6 +17,8 @@ def apply_templates(forward_code: str) -> str:
17
imports = "import torch"
18
if "device" in forward_code:
19
imports += "\n\nfrom torch import device"
20
+ if "inf" in forward_code:
21
+ imports += "\n\nfrom torch import inf"
22
return f"{imports}\n\nclass GraphModule(torch.nn.Module):\n{tab}{forward_code}"
23
24
0 commit comments