Skip to content

Commit db39646

Browse files
committed
[Bug Fix] Fix when tensor with shape [1] and when device in gm.code
1 parent 3afe6b0 commit db39646

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

graph_net/torch/utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,23 @@
88
import argparse
99
import importlib
1010
import inspect
11+
import math
1112

1213

1314
def apply_templates(forward_code: str) -> str:
1415
tab = " "
1516
forward_code = f"\n{tab}".join(forward_code.split("\n"))
16-
return f"import torch\n\nclass GraphModule(torch.nn.Module):\n{tab}{forward_code}"
17+
imports = "import torch"
18+
if "device" in forward_code:
19+
imports += "\n\nfrom torch import device"
20+
return f"{imports}\n\nclass GraphModule(torch.nn.Module):\n{tab}{forward_code}"
1721

1822

1923
def get_limited_precision_float_str(value):
2024
if not isinstance(value, float):
2125
return value
26+
if not math.isfinite(value):
27+
return f'float("{value}")'
2228
return f"{value:.3f}"
2329

2430

0 commit comments

Comments
 (0)