Skip to content

Commit 03c2fbc

Browse files
authored
Fix/duplicated trials (#240)
* implement `_check_duplicate` * run formatter * upd callback test
1 parent eb0a441 commit 03c2fbc

File tree

2 files changed

+19
-481
lines changed

2 files changed

+19
-481
lines changed

autointent/nodes/_node_optimizer.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,11 @@ def objective(
129129
"""
130130
module_name, module_hyperparams = self._suggest_module_and_hyperparams(trial, search_space)
131131

132+
if prev_metric := _check_duplicate(trial):
133+
msg = f"Duplicated trial with {module_name=}, {prev_metric=}, {module_hyperparams=}"
134+
logger.debug(msg)
135+
return prev_metric
136+
132137
self._logger.debug("Initializing %s module with config: %s", module_name, json.dumps(module_hyperparams))
133138
module = self.node_info.modules_available[module_name].from_context(context, **module_hyperparams)
134139
module_hyperparams.update(module.get_implicit_initialization_params())
@@ -425,3 +430,14 @@ def handle_message_on_mode(
425430
logger.warning(message)
426431
if strict:
427432
raise ValueError(message)
433+
434+
435+
# TODO research on possibility to use custom pruner
436+
def _check_duplicate(trial: Trial) -> float | None:
437+
completed_trials = trial.study.get_trials(states=[optuna.trial.TrialState.COMPLETE], deepcopy=False)
438+
439+
previous_trial = next(
440+
(completed_trial for completed_trial in completed_trials if completed_trial.params == trial.params), None
441+
)
442+
443+
return previous_trial.value if previous_trial is not None else None

0 commit comments

Comments
 (0)