|
| 1 | +import datetime |
1 | 2 | import os |
| 3 | +import uuid |
2 | 4 |
|
3 | 5 | import pandas as pd |
4 | 6 | import psycopg2 |
5 | 7 | import psycopg2.extras |
6 | 8 |
|
7 | 9 |
|
8 | | -FINAL_CSV_FILENAME = "benchmark_outputs/collated_results.csv" |
9 | | -TABLE_NAME = "diffusers_benchmarks" |
| 10 | +# FINAL_CSV_FILENAME = "benchmark_outputs/collated_results.csv" |
| 11 | +# Target table is defined in the transformers benchmark DB init script: https://github.com/huggingface/transformers/blob/593e29c5e2a9b17baec010e8dc7c1431fed6e841/benchmark/init_db.sql#L27 |
| 12 | +TABLE_NAME = "model_measurements" |
10 | 13 |
|
11 | 14 | if __name__ == "__main__": |
12 | | - conn = psycopg2.connect( |
13 | | - host=os.getenv("PGHOST"), |
14 | | - database=os.getenv("PGDATABASE"), |
15 | | - user=os.getenv("PGUSER"), |
16 | | - password=os.getenv("PGPASSWORD"), |
17 | | - ) |
| 15 | + try: |
| 16 | + conn = psycopg2.connect( |
| 17 | + host=os.getenv("PGHOST"), |
| 18 | + database=os.getenv("PGDATABASE"), |
| 19 | + user=os.getenv("PGUSER"), |
| 20 | + password=os.getenv("PGPASSWORD"), |
| 21 | + ) |
| 22 | + print("DB connection established successfully.") |
| 23 | + except Exception: |
| 24 | + raise |
18 | 25 | cur = conn.cursor() |
19 | 26 |
|
20 | | - cur.execute(f""" |
21 | | - CREATE TABLE IF NOT EXISTS {TABLE_NAME} ( |
22 | | - scenario TEXT, |
23 | | - model_cls TEXT, |
24 | | - num_params_M REAL, |
25 | | - flops_M REAL, |
26 | | - time_plain_s REAL, |
27 | | - mem_plain_GB REAL, |
28 | | - time_compile_s REAL, |
29 | | - mem_compile_GB REAL, |
30 | | - fullgraph BOOLEAN, |
31 | | - mode TEXT, |
32 | | - github_sha TEXT |
33 | | - ); |
34 | | - """) |
35 | | - conn.commit() |
36 | | - |
37 | | - df = pd.read_csv(FINAL_CSV_FILENAME) |
| 27 | + # df = pd.read_csv(FINAL_CSV_FILENAME) |
| 28 | + df = pd.read_csv("collated_results.csv") |
38 | 29 |
|
39 | 30 | # Helper to cast values (or None) given a dtype |
40 | 31 | def _cast_value(val, dtype: str): |
@@ -64,61 +55,60 @@ def _cast_value(val, dtype: str): |
64 | 55 |
|
65 | 56 | return val |
66 | 57 |
|
67 | | - rows_to_insert = [] |
68 | | - for _, row in df.iterrows(): |
69 | | - scenario = _cast_value(row.get("scenario"), "text") |
70 | | - model_cls = _cast_value(row.get("model_cls"), "text") |
71 | | - num_params_M = _cast_value(row.get("num_params_M"), "float") |
72 | | - flops_M = _cast_value(row.get("flops_M"), "float") |
73 | | - time_plain_s = _cast_value(row.get("time_plain_s"), "float") |
74 | | - mem_plain_GB = _cast_value(row.get("mem_plain_GB"), "float") |
75 | | - time_compile_s = _cast_value(row.get("time_compile_s"), "float") |
76 | | - mem_compile_GB = _cast_value(row.get("mem_compile_GB"), "float") |
77 | | - fullgraph = _cast_value(row.get("fullgraph"), "bool") |
78 | | - mode = _cast_value(row.get("mode"), "text") |
79 | | - |
80 | | - # If "github_sha" column exists in the CSV, cast it; else default to None |
81 | | - if "github_sha" in df.columns: |
82 | | - github_sha = _cast_value(row.get("github_sha"), "text") |
83 | | - else: |
84 | | - github_sha = None |
85 | | - |
86 | | - rows_to_insert.append( |
87 | | - ( |
88 | | - scenario, |
89 | | - model_cls, |
90 | | - num_params_M, |
91 | | - flops_M, |
92 | | - time_plain_s, |
93 | | - mem_plain_GB, |
94 | | - time_compile_s, |
95 | | - mem_compile_GB, |
96 | | - fullgraph, |
97 | | - mode, |
98 | | - github_sha, |
99 | | - ) |
| 58 | + try: |
| 59 | + rows_to_insert = [] |
| 60 | + id_for_benchmark = str(uuid.uuid4()) + "_" + datetime.datetime.now().strftime("%Y%m%d%H%M%S") |
| 61 | + for _, row in df.iterrows(): |
| 62 | + scenario = _cast_value(row.get("scenario"), "text") |
| 63 | + model_cls = _cast_value(row.get("model_cls"), "text") |
| 64 | + num_params_M = _cast_value(row.get("num_params_M"), "float") |
| 65 | + flops_M = _cast_value(row.get("flops_M"), "float") |
| 66 | + time_plain_s = _cast_value(row.get("time_plain_s"), "float") |
| 67 | + mem_plain_GB = _cast_value(row.get("mem_plain_GB"), "float") |
| 68 | + time_compile_s = _cast_value(row.get("time_compile_s"), "float") |
| 69 | + mem_compile_GB = _cast_value(row.get("mem_compile_GB"), "float") |
| 70 | + fullgraph = _cast_value(row.get("fullgraph"), "bool") |
| 71 | + mode = _cast_value(row.get("mode"), "text") |
| 72 | + |
| 73 | + # If "github_sha" column exists in the CSV, cast it; else default to None |
| 74 | + if "github_sha" in df.columns: |
| 75 | + github_sha = _cast_value(row.get("github_sha"), "text") |
| 76 | + else: |
| 77 | + github_sha = None |
| 78 | + |
| 79 | + if github_sha: |
| 80 | + benchmark_id = f"{model_cls}-{scenario}-{github_sha}" |
| 81 | + else: |
| 82 | + benchmark_id = f"{model_cls}-{scenario}-{id_for_benchmark}" |
| 83 | + |
| 84 | + measurements = { |
| 85 | + "scenario": scenario, |
| 86 | + "model_cls": model_cls, |
| 87 | + "num_params_M": num_params_M, |
| 88 | + "flops_M": flops_M, |
| 89 | + "time_plain_s": time_plain_s, |
| 90 | + "mem_plain_GB": mem_plain_GB, |
| 91 | + "time_compile_s": time_compile_s, |
| 92 | + "mem_compile_GB": mem_compile_GB, |
| 93 | + "fullgraph": fullgraph, |
| 94 | + "mode": mode, |
| 95 | + "github_sha": github_sha, |
| 96 | + } |
| 97 | + rows_to_insert.append((benchmark_id, measurements)) |
| 98 | + |
| 99 | + # Batch-insert all rows |
| 100 | + insert_sql = f""" |
| 101 | + INSERT INTO {TABLE_NAME} ( |
| 102 | + benchmark_id, |
| 103 | + measurements |
100 | 104 | ) |
| 105 | + VALUES (%s, %s); |
| 106 | + """ |
| 107 | + |
| 108 | + psycopg2.extras.execute_batch(cur, insert_sql, rows_to_insert) |
| 109 | + conn.commit() |
101 | 110 |
|
102 | | - # Batch-insert all rows (with NULL for any None) |
103 | | - insert_sql = """ |
104 | | - INSERT INTO benchmarks ( |
105 | | - scenario, |
106 | | - model_cls, |
107 | | - num_params_M, |
108 | | - flops_M, |
109 | | - time_plain_s, |
110 | | - mem_plain_GB, |
111 | | - time_compile_s, |
112 | | - mem_compile_GB, |
113 | | - fullgraph, |
114 | | - mode, |
115 | | - github_sha |
116 | | - ) |
117 | | - VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s); |
118 | | - """ |
119 | | - |
120 | | - psycopg2.extras.execute_batch(cur, insert_sql, rows_to_insert) |
121 | | - conn.commit() |
122 | | - |
123 | | - cur.close() |
124 | | - conn.close() |
| 111 | + cur.close() |
| 112 | + conn.close() |
| 113 | + except Exception as e: |
| 114 | + print(f"Exception: {e}") |
0 commit comments