Skip to content

Commit dfae7ca

Browse files
committed
fix duplicate params in same batch also
1 parent 17cd045 commit dfae7ca

File tree

1 file changed

+21
-17
lines changed

1 file changed

+21
-17
lines changed

freqtrade/optimize/hyperopt/hyperopt.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import rapidjson
1717
from joblib import Parallel, cpu_count
18-
from optuna.trial import Trial, TrialState
18+
from optuna.trial import FrozenTrial, Trial, TrialState
1919

2020
from freqtrade.constants import FTHYPT_FILEVERSION, LAST_BT_RESULT_FN, Config
2121
from freqtrade.enums import HyperoptState
@@ -171,15 +171,19 @@ def get_optuna_asked_points(self, n_points: int, dimensions: dict) -> list[Any]:
171171
asked.append(self.opt.ask(dimensions))
172172
return asked
173173

174-
def duplicate_optuna_asked_points(self, trial: Trial) -> bool:
174+
def duplicate_optuna_asked_points(self, trial: Trial, asked_trials: list[FrozenTrial]) -> bool:
175+
asked_trials_no_dups: list[FrozenTrial] = []
175176
trials_to_consider = trial.study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])
176177
# Check whether we already evaluated the sampled `params`.
177178
for t in reversed(trials_to_consider):
178179
if trial.params == t.params:
179-
# logger.warning(
180-
# f"duplicate trial: Trial {trial.number} has same params as {t.number}"
181-
# )
182180
return True
181+
# Check whether same`params` in one batch (asked_trials). Autosampler is doing this.
182+
for t in asked_trials:
183+
if t.params not in asked_trials_no_dups:
184+
asked_trials_no_dups.append(t)
185+
if len(asked_trials_no_dups) != len(asked_trials):
186+
return True
183187
return False
184188

185189
def get_asked_points(self, n_points: int, dimensions: dict) -> tuple[list[Any], list[bool]]:
@@ -189,26 +193,26 @@ def get_asked_points(self, n_points: int, dimensions: dict) -> tuple[list[Any],
189193
Steps:
190194
1. Try to get points using `self.opt.ask` first
191195
2. Discard the points that have already been evaluated
192-
3. Retry using `self.opt.ask` up to 5 times
196+
3. Retry using `self.opt.ask` up to `n_points` times
193197
"""
194-
asked_non_tried: list[list[Any]] = []
195-
asked_duplicates: list[Trial] = []
198+
asked_non_tried: list[FrozenTrial] = []
196199
optuna_asked_trials = self.get_optuna_asked_points(n_points=n_points, dimensions=dimensions)
197200
asked_non_tried += [
198-
x for x in optuna_asked_trials if not self.duplicate_optuna_asked_points(x)
201+
x
202+
for x in optuna_asked_trials
203+
if not self.duplicate_optuna_asked_points(x, optuna_asked_trials)
199204
]
200205
i = 0
201-
while i < 5 and len(asked_non_tried) < n_points:
206+
while i < 2 * n_points and len(asked_non_tried) < n_points:
202207
asked_new = self.get_optuna_asked_points(n_points=1, dimensions=dimensions)[0]
203-
if not self.duplicate_optuna_asked_points(asked_new):
208+
if not self.duplicate_optuna_asked_points(asked_new, asked_non_tried):
204209
asked_non_tried.append(asked_new)
205-
else:
206-
asked_duplicates.append(asked_new)
207210
i += 1
208-
if len(asked_duplicates) > 0 and len(asked_non_tried) < n_points:
209-
for asked_duplicate in asked_duplicates:
210-
logger.warning(f"duplicate params for Epoch {asked_duplicate.number}")
211-
self.count_skipped_epochs += len(asked_duplicates)
211+
if len(asked_non_tried) < n_points:
212+
logger.warning(
213+
"duplicate params detected. Please check if search space is not too small!"
214+
)
215+
self.count_skipped_epochs += n_points - len(asked_non_tried)
212216

213217
return asked_non_tried, [False for _ in range(len(asked_non_tried))]
214218

0 commit comments

Comments
 (0)