Skip to content

Commit 8161e36

Browse files
committed
fixes
1 parent 8ddf57c commit 8161e36

File tree

3 files changed

+15
-9
lines changed

3 files changed

+15
-9
lines changed

benchmarks/benchmarking_utils.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,12 @@ def run_benchmark(self, scenario: BenchmarkScenario):
112112
logger.info(f"Running scenario: {scenario.name}.")
113113
try:
114114
model = model_init_fn(scenario.model_cls, **scenario.model_init_kwargs)
115-
num_params = round(calculate_params(model) / 1e6, 2)
116-
flops = round(calculate_flops(model, input_dict=scenario.get_model_input_dict()) / 1e6, 2)
115+
num_params = round(calculate_params(model) / 1e9, 2)
116+
try:
117+
flops = round(calculate_flops(model, input_dict=scenario.get_model_input_dict()) / 1e9, 2)
118+
except Exception as e:
119+
logger.info(f"Problem in calculating FLOPs:\n{e}")
120+
flops = None
117121
model.cpu()
118122
del model
119123
except Exception as e:
@@ -156,8 +160,8 @@ def run_benchmark(self, scenario: BenchmarkScenario):
156160
result = {
157161
"scenario": scenario.name,
158162
"model_cls": scenario.model_cls.__name__,
159-
"num_params_M": num_params,
160-
"flops_M": flops,
163+
"num_params_B": num_params,
164+
"flops_G": flops,
161165
"time_plain_s": plain["time"],
162166
"mem_plain_GB": plain["memory"],
163167
"time_compile_s": compiled["time"],

benchmarks/populate_into_db.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,8 @@ def _cast_value(val, dtype: str):
6060
for _, row in df.iterrows():
6161
scenario = _cast_value(row.get("scenario"), "text")
6262
model_cls = _cast_value(row.get("model_cls"), "text")
63-
num_params_M = _cast_value(row.get("num_params_M"), "float")
64-
flops_M = _cast_value(row.get("flops_M"), "float")
63+
num_params_B = _cast_value(row.get("num_params_B"), "float")
64+
flops_G = _cast_value(row.get("flops_G"), "float")
6565
time_plain_s = _cast_value(row.get("time_plain_s"), "float")
6666
mem_plain_GB = _cast_value(row.get("mem_plain_GB"), "float")
6767
time_compile_s = _cast_value(row.get("time_compile_s"), "float")
@@ -84,8 +84,8 @@ def _cast_value(val, dtype: str):
8484
"repository": "huggingface/diffusers",
8585
"scenario": scenario,
8686
"model_cls": model_cls,
87-
"num_params_M": num_params_M,
88-
"flops_M": flops_M,
87+
"num_params_B": num_params_B,
88+
"flops_G": flops_G,
8989
"time_plain_s": time_plain_s,
9090
"mem_plain_GB": mem_plain_GB,
9191
"time_compile_s": time_compile_s,

benchmarks/push_results.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import os
2+
13
import pandas as pd
24
from huggingface_hub import hf_hub_download, upload_file
35
from huggingface_hub.utils import EntryNotFoundError
@@ -50,7 +52,7 @@ def push_to_hf_dataset():
5052

5153
# combine
5254
current_results[column] = curr_str + append_str
53-
55+
os.remove(FINAL_CSV_FILENAME)
5456
current_results.to_csv(FINAL_CSV_FILENAME, index=False)
5557

5658
commit_message = f"upload from sha: {GITHUB_SHA}" if GITHUB_SHA is not None else "upload benchmark results"

0 commit comments

Comments
 (0)