Skip to content

Commit b06881a

Browse files
authored
[Bug Fix] Fix std when numel eq to 1 (#102)
1 parent a7f00f8 commit b06881a

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

graph_net/torch/utils.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,19 @@ def apply_templates(forward_code: str) -> str:
2323
def get_limited_precision_float_str(value):
2424
if not isinstance(value, float):
2525
return value
26-
if not math.isfinite(value):
27-
return f'float("{value}")'
2826
return f"{value:.3f}"
2927

3028

3129
def convert_state_and_inputs_impl(state_dict, example_inputs):
3230
def tensor_info(tensor):
3331
is_float = tensor.dtype.is_floating_point
3432
mean = float(tensor.mean().item()) if is_float else None
35-
std = float(tensor.std().item()) if is_float else None
33+
std = None
34+
if is_float:
35+
if tensor.numel() <= 1:
36+
std = 0.0
37+
else:
38+
std = float(tensor.std().item())
3639
return {
3740
"shape": list(tensor.shape),
3841
"dtype": str(tensor.dtype),

0 commit comments

Comments
 (0)