Skip to content

Commit 773c237

Browse files
committed
[Bugfix] fix avg_std_strf
1 parent 9b56b1e commit 773c237

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

ai_infra_bench/modes/gen.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,9 @@ def gen_export_table(
8181
row_values.append(f"{row_results[0][feature]:.2f}")
8282
row_values.append("-")
8383
for metric in output_metrics:
84-
row_values.append(avg_std_strf(metric, row_results, precision=2))
84+
row_values.append(
85+
avg_std_strf(key=metric, item_list=row_results, precision=2)
86+
)
8587
lines.append("| " + " | ".join(row_values) + " |")
8688

8789
with open(os.path.join(output_dir, TABLE_NAME), mode="w", encoding="utf-8") as f:

ai_infra_bench/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import time
1212
from dataclasses import dataclass
1313
from functools import wraps
14-
from typing import Dict, List
14+
from typing import Any, Dict, List
1515

1616
import numpy as np
1717
import psutil
@@ -210,7 +210,7 @@ def read_jsonl(filepath: str):
210210

211211

212212
def avg_std_strf(
213-
key: str, item_list: List[Dict[str, float]], *, sep=", ", precision: int = None
213+
key: str, item_list: List[Dict[str, Any]], *, sep=", ", precision: int = None
214214
) -> str:
215215
val_list = [item[key] for item in item_list]
216216

@@ -219,7 +219,7 @@ def avg_std_strf(
219219
if not isinstance(val_list[0], (int, float)):
220220
return str(val_list[0])
221221

222-
if len(val_list) == 1 or (std := np.std(val_list, ddof=1)):
222+
if len(val_list) == 1 or (std := np.std(val_list, ddof=1)) == 0:
223223
return format(val_list[0], fmt)
224224

225225
avg = np.mean(val_list)

0 commit comments

Comments
 (0)