We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 3afe6b0 commit db39646Copy full SHA for db39646
graph_net/torch/utils.py
@@ -8,17 +8,23 @@
8
import argparse
9
import importlib
10
import inspect
11
+import math
12
13
14
def apply_templates(forward_code: str) -> str:
15
tab = " "
16
forward_code = f"\n{tab}".join(forward_code.split("\n"))
- return f"import torch\n\nclass GraphModule(torch.nn.Module):\n{tab}{forward_code}"
17
+ imports = "import torch"
18
+ if "device" in forward_code:
19
+ imports += "\n\nfrom torch import device"
20
+ return f"{imports}\n\nclass GraphModule(torch.nn.Module):\n{tab}{forward_code}"
21
22
23
def get_limited_precision_float_str(value):
24
if not isinstance(value, float):
25
return value
26
+ if not math.isfinite(value):
27
+ return f'float("{value}")'
28
return f"{value:.3f}"
29
30
0 commit comments