
Commit cfbd21e

utility to push to the hub.
1 parent 5635bf8 commit cfbd21e

2 files changed: +80 -2 lines


benchmarks/push_results.py

Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
import os

import pandas as pd
import torch
from huggingface_hub import hf_hub_download, upload_file
from huggingface_hub.utils import EntryNotFoundError


if torch.cuda.is_available():
    TOTAL_GPU_MEMORY = float(
        os.getenv("TOTAL_GPU_MEMORY", torch.cuda.get_device_properties(0).total_memory / (1024**3))
    )
else:
    raise RuntimeError("This utility requires a CUDA-capable GPU.")

REPO_ID = "diffusers/benchmarks"


def has_previous_benchmark() -> str:
    from run_all import FINAL_CSV_FILENAME

    csv_path = None
    try:
        csv_path = hf_hub_download(repo_id=REPO_ID, repo_type="dataset", filename=FINAL_CSV_FILENAME)
    except EntryNotFoundError:
        csv_path = None
    return csv_path


def filter_float(value):
    if isinstance(value, str):
        return float(value.split()[0])
    return value


def push_to_hf_dataset():
    from run_all import FINAL_CSV_FILENAME, GITHUB_SHA

    # If there's an existing benchmark file, we should report the changes.
    csv_path = has_previous_benchmark()
    if csv_path is not None:
        current_results = pd.read_csv(FINAL_CSV_FILENAME)
        previous_results = pd.read_csv(csv_path)

        # identify the numeric columns we want to annotate
        numeric_columns = current_results.select_dtypes(include=["float64", "int64"]).columns

        # for each numeric column, append the old value in () if present
        for column in numeric_columns:
            # coerce any "x units" strings back to float
            prev_vals = previous_results[column].map(filter_float)
            # align indices in case rows were added/removed
            prev_vals = prev_vals.reindex(current_results.index)

            # build the new string: "current_value (previous_value)"
            curr_str = current_results[column].astype(str)
            prev_str = prev_vals.map(lambda x: f" ({x})" if pd.notnull(x) else "")

            current_results[column] = curr_str + prev_str

        # overwrite the CSV
        current_results.to_csv(FINAL_CSV_FILENAME, index=False)

    commit_message = f"upload from sha: {GITHUB_SHA}" if GITHUB_SHA is not None else "upload benchmark results"
    upload_file(
        repo_id=REPO_ID,
        path_in_repo=FINAL_CSV_FILENAME,
        path_or_fileobj=FINAL_CSV_FILENAME,
        repo_type="dataset",
        commit_message=commit_message,
    )


if __name__ == "__main__":
    push_to_hf_dataset()
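
For illustration, here is a small, self-contained sketch of the "current (previous)" annotation that push_to_hf_dataset applies when an earlier CSV exists on the Hub. The data and column names (model, time (s)) are invented for this example and are not part of the benchmark suite:

import pandas as pd

# Toy stand-ins for collated_results.csv and the previously uploaded CSV.
current = pd.DataFrame({"model": ["sd15", "sdxl"], "time (s)": [3.1, 7.8]})
previous = pd.DataFrame({"model": ["sd15", "sdxl"], "time (s)": ["3.4 s", 8.0]})


def filter_float(value):
    # Same idea as in push_results.py: keep only the leading number of strings like "3.4 s".
    if isinstance(value, str):
        return float(value.split()[0])
    return value


for column in current.select_dtypes(include=["float64", "int64"]).columns:
    prev_vals = previous[column].map(filter_float).reindex(current.index)
    current[column] = current[column].astype(str) + prev_vals.map(
        lambda x: f" ({x})" if pd.notnull(x) else ""
    )

print(current)
# Prints something like:
#   model   time (s)
# 0  sd15  3.1 (3.4)
# 1  sdxl  7.8 (8.0)

The uploaded CSV therefore keeps one column per metric, with the previous run's value preserved in parentheses for easy comparison.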

benchmarks/run_all.py

Lines changed: 5 additions & 2 deletions
@@ -1,12 +1,15 @@
 import glob
+import os
 import subprocess
+
 import pandas as pd
-import os
+

 PATTERN = "benchmarking_*.py"
 FINAL_CSV_FILENAME = "collated_results.csv"
 GITHUB_SHA = os.getenv("GITHUB_SHA", None)

+
 class SubprocessCallException(Exception):
     pass

@@ -33,7 +36,7 @@ def run_scripts():
     python_files = sorted(glob.glob(PATTERN))

     for file in python_files:
-        if file != "benchmarking_utils.py":
+        if file != "benchmarking_utils.py":
             print(f"****** Running file: {file} ******")
             command = f"python {file}"
             run_command(command.split())
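
run_scripts() calls run_command, and the module defines SubprocessCallException (visible as context in the first hunk); run_command itself is not shown in this diff. A minimal sketch of how such a helper is commonly written around subprocess, assuming it only needs to raise a readable error when a benchmarking script fails (an illustration, not the file's actual implementation):

import subprocess


class SubprocessCallException(Exception):
    pass


def run_command(command, return_stdout=False):
    # Run one benchmarking script as a child process; surface its combined output on failure.
    try:
        output = subprocess.check_output(command, stderr=subprocess.STDOUT)
        if return_stdout and hasattr(output, "decode"):
            return output.decode("utf-8")
    except subprocess.CalledProcessError as e:
        raise SubprocessCallException(
            f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}"
        ) from e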
