Skip to content

Commit 6b11973

Browse files
committed
address feedback
1 parent 6bfdae6 commit 6b11973

File tree

2 files changed

+62
-18
lines changed

2 files changed

+62
-18
lines changed

.github/workflows/benchmark.yml

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,17 @@ jobs:
6565
path: benchmarks/${{ env.BASE_PATH }}
6666

6767
# TODO: enable this once the connection problem has been resolved.
68-
# - name: Update benchmarking results to DB
69-
# env:
70-
# PGDATABASE: metrics
71-
# PGHOST: ${{ secrets.DIFFUSERS_BENCHMARKS_PGHOST }}
72-
# PGUSER: transformers_benchmarks
73-
# PGPASSWORD: ${{ secrets.DIFFUSERS_BENCHMARKS_PGPASSWORD }}
74-
# run: cd benchmarks && python populate_into_db.py
68+
- name: Update benchmarking results to DB
69+
env:
70+
PGDATABASE: metrics
71+
PGHOST: ${{ secrets.DIFFUSERS_BENCHMARKS_PGHOST }}
72+
PGUSER: transformers_benchmarks
73+
PGPASSWORD: ${{ secrets.DIFFUSERS_BENCHMARKS_PGPASSWORD }}
74+
BRANCH_NAME: ${{ github.head_ref || github.ref_name }}
75+
run: |
76+
commit_id=$GITHUB_SHA
77+
commit_msg=$(git show -s --format=%s "$commit_id" | cut -c1-70)
78+
cd benchmarks && python populate_into_db.py "$BRANCH_NAME" "$commit_id" "$commit_msg"
7579
7680
- name: Report success status
7781
if: ${{ success() }}

benchmarks/populate_into_db.py

Lines changed: 51 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
import datetime
1+
import argparse
22
import os
3-
import uuid
3+
import sys
44

55
import pandas as pd
66
import psycopg2
@@ -12,7 +12,45 @@
1212
BENCHMARKS_TABLE_NAME = "benchmarks"
1313
MEASUREMENTS_TABLE_NAME = "model_measurements"
1414

15+
16+
def _init_benchmark(conn, branch, commit_id, commit_msg):
17+
metadata = {}
18+
repository = "huggingface/diffusers"
19+
with conn.cursor() as cur:
20+
cur.execute(
21+
f"INSERT INTO {BENCHMARKS_TABLE_NAME} (repository, branch, commit_id, commit_message, metadata) VALUES (%s, %s, %s, %s, %s) RETURNING benchmark_id",
22+
(repository, branch, commit_id, commit_msg, metadata),
23+
)
24+
benchmark_id = cur.fetchone()[0]
25+
print(f"Initialised benchmark #{benchmark_id}")
26+
return benchmark_id
27+
28+
29+
def parse_args():
30+
parser = argparse.ArgumentParser()
31+
parser.add_argument(
32+
"branch",
33+
type=str,
34+
help="The branch name on which the benchmarking is performed.",
35+
)
36+
37+
parser.add_argument(
38+
"commit_id",
39+
type=str,
40+
help="The commit hash on which the benchmarking is performed.",
41+
)
42+
43+
parser.add_argument(
44+
"commit_msg",
45+
type=str,
46+
help="The commit message associated with the commit, truncated to 70 characters.",
47+
)
48+
args = parser.parse_args()
49+
return args
50+
51+
1552
if __name__ == "__main__":
53+
args = parse_args()
1654
try:
1755
conn = psycopg2.connect(
1856
host=os.getenv("PGHOST"),
@@ -21,8 +59,17 @@
2159
password=os.getenv("PGPASSWORD"),
2260
)
2361
print("DB connection established successfully.")
24-
except Exception:
25-
raise
62+
except Exception as e:
63+
print(f"Problem during DB init: {e}")
64+
sys.exit(1)
65+
66+
benchmark_id = _init_benchmark(
67+
conn=conn,
68+
branch=args.branch,
69+
commit_id=args.commit_id,
70+
commit_msg=args.commit_msg,
71+
)
72+
2673
cur = conn.cursor()
2774

2875
df = pd.read_csv(FINAL_CSV_FILENAME)
@@ -57,7 +104,6 @@ def _cast_value(val, dtype: str):
57104

58105
try:
59106
rows_to_insert = []
60-
id_for_benchmark = str(uuid.uuid4()) + "_" + datetime.datetime.now().strftime("%Y%m%d%H%M%S")
61107
for _, row in df.iterrows():
62108
scenario = _cast_value(row.get("scenario"), "text")
63109
model_cls = _cast_value(row.get("model_cls"), "text")
@@ -76,13 +122,7 @@ def _cast_value(val, dtype: str):
76122
else:
77123
github_sha = None
78124

79-
if github_sha:
80-
benchmark_id = f"{model_cls}-{scenario}-{github_sha}"
81-
else:
82-
benchmark_id = f"{model_cls}-{scenario}-{id_for_benchmark}"
83-
84125
measurements = {
85-
"repository": "huggingface/diffusers",
86126
"scenario": scenario,
87127
"model_cls": model_cls,
88128
"num_params_B": num_params_B,

0 commit comments

Comments
 (0)