diff --git a/graph_net/torch/utils.py b/graph_net/torch/utils.py index 8419e5d95..f8df92090 100644 --- a/graph_net/torch/utils.py +++ b/graph_net/torch/utils.py @@ -8,17 +8,23 @@ import argparse import importlib import inspect +import math def apply_templates(forward_code: str) -> str: tab = " " forward_code = f"\n{tab}".join(forward_code.split("\n")) - return f"import torch\n\nclass GraphModule(torch.nn.Module):\n{tab}{forward_code}" + imports = "import torch" + if "device" in forward_code: + imports += "\n\nfrom torch import device" + return f"{imports}\n\nclass GraphModule(torch.nn.Module):\n{tab}{forward_code}" def get_limited_precision_float_str(value): if not isinstance(value, float): return value + if not math.isfinite(value): + return f'float("{value}")' return f"{value:.3f}"