1- import datetime
1+ import argparse
22import os
3- import uuid
3+ import sys
44
55import pandas as pd
66import psycopg2
1212BENCHMARKS_TABLE_NAME = "benchmarks"
1313MEASUREMENTS_TABLE_NAME = "model_measurements"
1414
15+
def _init_benchmark(conn, branch, commit_id, commit_msg):
    """Register a new benchmark run and return its database id.

    Inserts one row into the table named by ``BENCHMARKS_TABLE_NAME`` for the
    "huggingface/diffusers" repository, with an empty metadata mapping, and
    reads back the id produced by the statement's ``RETURNING`` clause.

    Args:
        conn: Open DB connection exposing ``cursor()`` as a context manager.
        branch: Branch name stored with the run.
        commit_id: Commit hash stored with the run.
        commit_msg: Commit message stored with the run.

    Returns:
        The ``benchmark_id`` of the freshly inserted row.
    """
    # Values are passed separately from the SQL text so the driver binds them.
    insert_sql = (
        f"INSERT INTO {BENCHMARKS_TABLE_NAME} "
        "(repository, branch, commit_id, commit_message, metadata) "
        "VALUES (%s, %s, %s, %s, %s) RETURNING benchmark_id"
    )
    row_values = ("huggingface/diffusers", branch, commit_id, commit_msg, {})

    with conn.cursor() as cursor:
        cursor.execute(insert_sql, row_values)
        inserted_row = cursor.fetchone()
        benchmark_id = inserted_row[0]
        print(f"Initialised benchmark #{benchmark_id} ")
        return benchmark_id
27+
28+
def parse_args():
    """Parse the script's command-line arguments.

    Three positional string arguments are expected, in this order:
    ``branch``, ``commit_id`` and ``commit_msg``.

    Returns:
        The ``argparse.Namespace`` holding the parsed values.
    """
    # (argument name, help text) pairs, registered in positional order.
    positional_specs = (
        ("branch", "The branch name on which the benchmarking is performed."),
        ("commit_id", "The commit hash on which the benchmarking is performed."),
        (
            "commit_msg",
            "The commit message associated with the commit, truncated to 70 characters.",
        ),
    )

    cli = argparse.ArgumentParser()
    for arg_name, help_text in positional_specs:
        cli.add_argument(arg_name, type=str, help=help_text)
    return cli.parse_args()
50+
51+
1552if __name__ == "__main__" :
53+ args = parse_args ()
1654 try :
1755 conn = psycopg2 .connect (
1856 host = os .getenv ("PGHOST" ),
2159 password = os .getenv ("PGPASSWORD" ),
2260 )
2361 print ("DB connection established successfully." )
24- except Exception :
25- raise
62+ except Exception as e :
63+ print (f"Problem during DB init: { e } " )
64+ sys .exit (1 )
65+
66+ benchmark_id = _init_benchmark (
67+ conn = conn ,
68+ branch = args .branch ,
69+ commit_id = args .commit_id ,
70+ commit_msg = args .commit_msg ,
71+ )
72+
2673 cur = conn .cursor ()
2774
2875 df = pd .read_csv (FINAL_CSV_FILENAME )
@@ -57,7 +104,6 @@ def _cast_value(val, dtype: str):
57104
58105 try :
59106 rows_to_insert = []
60- id_for_benchmark = str (uuid .uuid4 ()) + "_" + datetime .datetime .now ().strftime ("%Y%m%d%H%M%S" )
61107 for _ , row in df .iterrows ():
62108 scenario = _cast_value (row .get ("scenario" ), "text" )
63109 model_cls = _cast_value (row .get ("model_cls" ), "text" )
@@ -76,13 +122,7 @@ def _cast_value(val, dtype: str):
76122 else :
77123 github_sha = None
78124
79- if github_sha :
80- benchmark_id = f"{ model_cls } -{ scenario } -{ github_sha } "
81- else :
82- benchmark_id = f"{ model_cls } -{ scenario } -{ id_for_benchmark } "
83-
84125 measurements = {
85- "repository" : "huggingface/diffusers" ,
86126 "scenario" : scenario ,
87127 "model_cls" : model_cls ,
88128 "num_params_B" : num_params_B ,
0 commit comments