|
4 | 4 | from datetime import datetime |
5 | 5 | import uuid |
6 | 6 | import json |
| 7 | +import time |
| 8 | + |
7 | 9 | # get environment variables |
8 | 10 | DB_HOST = os.environ.get('DB_HOST', 'localhost') |
9 | 11 | DB_PORT = os.environ.get('DB_PORT', '5432') |
@@ -40,13 +42,27 @@ def create_trial( model_id, experiment_id, cur, conn,source_id="",completed_at=N |
40 | 42 | return trial_id |
41 | 43 |
|
42 | 44 | def create_trial_inputs(trial_id, inputs, cur, conn): |
43 | | - cur.execute("SELECT MAX(id) as id FROM trial_inputs") |
44 | | - try: |
45 | | - max_id = int(cur.fetchone()["id"]) |
46 | | - except: |
47 | | - max_id = 0 |
48 | | - cur.execute("INSERT INTO trial_inputs (id,created_at,updated_at,trial_id, url) VALUES (%s,%s,%s,%s, %s)", (max_id+1, datetime.now(), datetime.now(),trial_id, json.dumps(inputs))) |
49 | | - conn.commit() |
| 45 | + while True: |
| 46 | + try: |
| 47 | + # Fetch the latest max ID |
| 48 | + cur.execute("SELECT MAX(id) as id FROM trial_inputs") |
| 49 | + max_id = cur.fetchone()["id"] |
| 50 | + max_id = int(max_id) if max_id is not None else 0 # Handle NULL case |
| 51 | + |
| 52 | + # Attempt to insert with incremented ID |
| 53 | + cur.execute(""" |
| 54 | + INSERT INTO trial_inputs (id, created_at, updated_at, trial_id, url) |
| 55 | + VALUES (%s, %s, %s, %s, %s) |
| 56 | + """, (max_id + 1, datetime.now(), datetime.now(), trial_id, json.dumps(inputs))) |
| 57 | + |
| 58 | + conn.commit() |
| 59 | + break # Exit loop on success |
| 60 | + |
| 61 | + except Exception as e: |
| 62 | + print(f"Unexpected error: {e}") |
| 63 | + conn.rollback() # Rollback in case of duplicate key violation |
| 64 | + time.sleep(0.1) # Small delay before retrying to avoid excessive looping |
| 65 | + continue # Retry fetching max_id and inserting again |
50 | 66 |
|
51 | 67 | def create_expriement( cur, conn): |
52 | 68 | experiment_id= str(uuid.uuid4()) |
|
0 commit comments