Skip to content

Commit 7b60adf

Browse files
committed
updated mpi doptuna
1 parent d1a8349 commit 7b60adf

File tree

1 file changed

+26
-2
lines changed

1 file changed

+26
-2
lines changed

deephyper_benchmark/search/_mpi_doptuna.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from deephyper.core.utils._timeout import terminate_on_timeout # noqa: E402
2222
from deephyper.evaluator import RunningJob # noqa: E402
2323
from deephyper.search import Search # noqa: E402
24+
from deephyper.skopt.moo import non_dominated_set
2425

2526

2627
def optuna_suggest_from_hp(trial, cs_hp):
@@ -284,7 +285,7 @@ def objective_wrapper(trial):
284285
constraints = []
285286
for i, lbi in enumerate(self._moo_lower_bounds):
286287
if lbi is not None and type(output["objective"][i]) is not str:
287-
ci = -(output["objective"][i] - lbi) # <= 0
288+
ci = -(output["objective"][i] - lbi) # <= 0
288289
constraints.append(ci)
289290
trial.set_user_attr("constraints", tuple(constraints))
290291

@@ -356,8 +357,31 @@ def optimize_wrapper(duration):
356357

357358
df_results = pd.DataFrame([t.user_attrs["results"] for t in all_trials])
358359
df_path = os.path.join(self._log_dir, "results.csv")
360+
361+
# Check if Multi-Objective Optimization was performed to save the pareto front
362+
objective_columns = [
363+
col for col in df_results.columns if col.startswith("objective")
364+
]
365+
366+
if len(objective_columns) > 1:
367+
if pd.api.types.is_string_dtype(df_results[objective_columns[0]]):
368+
mask_no_failures = ~df_results[objective_columns[0]].str.startswith(
369+
"F"
370+
)
371+
else:
372+
mask_no_failures = np.ones(len(df_results), dtype=bool)
373+
objectives = -df_results.loc[
374+
mask_no_failures, objective_columns
375+
].values.astype(float)
376+
mask_pareto_front = non_dominated_set(objectives)
377+
df_results["pareto_efficient"] = False
378+
df_results.loc[mask_no_failures, "pareto_efficient"] = mask_pareto_front
379+
359380
df_results.to_csv(df_path, index=False)
360381

361-
self.extend_results_with_pareto_efficient(df_path)
382+
if self.comm:
383+
self.comm.Barrier()
384+
362385
df_results = pd.read_csv(df_path)
386+
363387
return df_results

0 commit comments

Comments
 (0)