1- import  datetime 
1+ import  argparse 
22import  os 
3- import  uuid 
3+ import  sys 
44
55import  pandas  as  pd 
66import  psycopg2 
1212BENCHMARKS_TABLE_NAME  =  "benchmarks" 
1313MEASUREMENTS_TABLE_NAME  =  "model_measurements" 
1414
15+ 
16+ def  _init_benchmark (conn , branch , commit_id , commit_msg ):
17+     metadata  =  {}
18+     repository  =  "huggingface/diffusers" 
19+     with  conn .cursor () as  cur :
20+         cur .execute (
21+             f"INSERT INTO { BENCHMARKS_TABLE_NAME }  ,
22+             (repository , branch , commit_id , commit_msg , metadata ),
23+         )
24+         benchmark_id  =  cur .fetchone ()[0 ]
25+         print (f"Initialised benchmark #{ benchmark_id }  )
26+         return  benchmark_id 
27+ 
28+ 
29+ def  parse_args ():
30+     parser  =  argparse .ArgumentParser ()
31+     parser .add_argument (
32+         "branch" ,
33+         type = str ,
34+         help = "The branch name on which the benchmarking is performed." ,
35+     )
36+ 
37+     parser .add_argument (
38+         "commit_id" ,
39+         type = str ,
40+         help = "The commit hash on which the benchmarking is performed." ,
41+     )
42+ 
43+     parser .add_argument (
44+         "commit_msg" ,
45+         type = str ,
46+         help = "The commit message associated with the commit, truncated to 70 characters." ,
47+     )
48+     args  =  parser .parse_args ()
49+     return  args 
50+ 
51+ 
1552if  __name__  ==  "__main__" :
53+     args  =  parse_args ()
1654    try :
1755        conn  =  psycopg2 .connect (
1856            host = os .getenv ("PGHOST" ),
2159            password = os .getenv ("PGPASSWORD" ),
2260        )
2361        print ("DB connection established successfully." )
24-     except  Exception :
25-         raise 
62+     except  Exception  as  e :
63+         print (f"Problem during DB init: { e }  )
64+         sys .exit (1 )
65+ 
66+     benchmark_id  =  _init_benchmark (
67+         conn = conn ,
68+         branch = args .branch ,
69+         commit_id = args .commit_id ,
70+         commit_msg = args .commit_msg ,
71+     )
72+ 
2673    cur  =  conn .cursor ()
2774
2875    df  =  pd .read_csv (FINAL_CSV_FILENAME )
@@ -57,7 +104,6 @@ def _cast_value(val, dtype: str):
57104
58105    try :
59106        rows_to_insert  =  []
60-         id_for_benchmark  =  str (uuid .uuid4 ()) +  "_"  +  datetime .datetime .now ().strftime ("%Y%m%d%H%M%S" )
61107        for  _ , row  in  df .iterrows ():
62108            scenario  =  _cast_value (row .get ("scenario" ), "text" )
63109            model_cls  =  _cast_value (row .get ("model_cls" ), "text" )
@@ -76,13 +122,7 @@ def _cast_value(val, dtype: str):
76122            else :
77123                github_sha  =  None 
78124
79-             if  github_sha :
80-                 benchmark_id  =  f"{ model_cls } { scenario } { github_sha }  
81-             else :
82-                 benchmark_id  =  f"{ model_cls } { scenario } { id_for_benchmark }  
83- 
84125            measurements  =  {
85-                 "repository" : "huggingface/diffusers" ,
86126                "scenario" : scenario ,
87127                "model_cls" : model_cls ,
88128                "num_params_B" : num_params_B ,
0 commit comments