|
1 | 1 | import optuna |
| 2 | +from optuna.storages import JournalStorage, JournalFileBackend |
| 3 | +from optuna.storages.journal import JournalFileOpenLock |
2 | 4 | import os |
3 | 5 | import mlflow |
4 | 6 | from datetime import datetime |
|
22 | 24 |
|
23 | 25 |
|
24 | 26 | mlflow.set_tracking_uri(uri=f"http://127.0.0.1:{MLFLOW_PORT}") |
| 27 | +mlflow.set_experiment(EXPERIMENT_NAME) |
25 | 28 |
|
26 | 29 |
|
| 30 | +# Optuna Storage Essentials |
| 31 | +# Use JournalFileStorage to ensure concurrency safety |
27 | 32 |
|
28 | | -mlflow.set_experiment(EXPERIMENT_NAME) |
| 33 | +storage_file = f"./optuna_{EXPERIMENT_NAME}.log" |
| 34 | +lock_obj = JournalFileOpenLock(storage_file) |
29 | 35 |
|
| 36 | +# Create the JournalStorage instance |
| 37 | +optuna_storage = JournalStorage(JournalFileBackend(storage_file, lock_obj=lock_obj)) |
30 | 38 |
|
31 | 39 |
|
32 | 40 | def objective(trial: optuna.Trial) -> float: |
@@ -1445,8 +1453,7 @@ def main(): |
1445 | 1453 | n_trials = N_TRIALS |
1446 | 1454 | # mlflow_parent = mlflow.start_run(run_name=os.getenv("MLFLOW_PARENT_RUN_NAME", "cerebros_poc_parent"), tags={"phase": "poc", "mode": "fast" if fast else "full"}) |
1447 | 1455 | sampler = optuna.samplers.TPESampler(multivariate=True, n_startup_trials=5) |
1448 | | - storage_name = f"sqlite:///{EXPERIMENT_NAME}.db |
1449 | | - study = optuna.create_study(direction="minimize", sampler=sampler, storage=storage_name) |
| 1456 | + study = optuna.create_study(direction="minimize", sampler=sampler, storage=optuna_storage) |
1450 | 1457 | study.optimize(objective, n_trials=n_trials) |
1451 | 1458 | # mlflow.log_param("n_trials", n_trials) |
1452 | 1459 | # Log fixed (non-tunable) generation control param once at parent level |
|
0 commit comments