Commit e8145e0

Merge pull request freqtrade#11805 from viotemp1/optuna_addons

fix hyperopt repeated parameters between batches

2 parents: 6188694 + ae90738

File tree: 2 files changed, +53 -44 lines

- freqtrade/optimize/hyperopt/hyperopt.py
- freqtrade/optimize/hyperopt/hyperopt_optimizer.py
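The commit addresses a real property of Optuna's ask interface: samplers may propose parameter sets that were already evaluated, both across batches and within a single batch. A minimal standalone sketch of that behaviour (the `buy_rsi` space is hypothetical, not part of this commit):

```python
import optuna

# Deliberately tiny search space: repeats become almost certain, which is
# exactly the "search space too small" situation this commit warns about.
study = optuna.create_study(sampler=optuna.samplers.RandomSampler(seed=42))

seen: list[dict] = []
for _ in range(6):
    trial = study.ask()
    trial.suggest_categorical("buy_rsi", [20, 30])
    if trial.params in seen:
        print(f"duplicate params proposed: {trial.params}")
    else:
        seen.append(trial.params)
    study.tell(trial, 0.0)  # dummy loss so the trial counts as COMPLETE
```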

freqtrade/optimize/hyperopt/hyperopt.py

Lines changed: 42 additions & 43 deletions
```diff
@@ -15,13 +15,14 @@
 
 import rapidjson
 from joblib import Parallel, cpu_count
+from optuna.trial import FrozenTrial, Trial, TrialState
 
 from freqtrade.constants import FTHYPT_FILEVERSION, LAST_BT_RESULT_FN, Config
 from freqtrade.enums import HyperoptState
 from freqtrade.exceptions import OperationalException
 from freqtrade.misc import file_dump_json, plural
 from freqtrade.optimize.hyperopt.hyperopt_logger import logging_mp_handle, logging_mp_setup
-from freqtrade.optimize.hyperopt.hyperopt_optimizer import HyperOptimizer
+from freqtrade.optimize.hyperopt.hyperopt_optimizer import INITIAL_POINTS, HyperOptimizer
 from freqtrade.optimize.hyperopt.hyperopt_output import HyperoptOutput
 from freqtrade.optimize.hyperopt_tools import (
     HyperoptStateContainer,
```
```diff
@@ -34,9 +35,6 @@
 logger = logging.getLogger(__name__)
 
 
-INITIAL_POINTS = 30
-
-
 log_queue: Any
 
 
```
```diff
@@ -91,6 +89,7 @@ def __init__(self, config: Config) -> None:
         self.print_json = self.config.get("print_json", False)
 
         self.hyperopter = HyperOptimizer(self.config, self.data_pickle_file)
+        self.count_skipped_epochs = 0
 
     @staticmethod
     def get_lock_filename(config: Config) -> str:
```
```diff
@@ -169,56 +168,49 @@ def get_optuna_asked_points(self, n_points: int, dimensions: dict) -> list[Any]:
             asked.append(self.opt.ask(dimensions))
         return asked
 
+    def duplicate_optuna_asked_points(self, trial: Trial, asked_trials: list[FrozenTrial]) -> bool:
+        asked_trials_no_dups: list[FrozenTrial] = []
+        trials_to_consider = trial.study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])
+        # Check whether we already evaluated the sampled `params`.
+        for t in reversed(trials_to_consider):
+            if trial.params == t.params:
+                return True
+        # Check whether the same `params` occur within one batch (asked_trials); the AutoSampler does this.
+        for t in asked_trials:
+            if t.params not in [x.params for x in asked_trials_no_dups]:
+                asked_trials_no_dups.append(t)
+        if len(asked_trials_no_dups) != len(asked_trials):
+            return True
+        return False
+
     def get_asked_points(self, n_points: int, dimensions: dict) -> tuple[list[Any], list[bool]]:
         """
         Enforce points returned from `self.opt.ask` have not been already evaluated
 
         Steps:
         1. Try to get points using `self.opt.ask` first
         2. Discard the points that have already been evaluated
-        3. Retry using `self.opt.ask` up to 3 times
-        4. If still some points are missing in respect to `n_points`, random sample some points
-        5. Repeat until at least `n_points` points in the `asked_non_tried` list
-        6. Return a list with length truncated at `n_points`
+        3. Retry using `self.opt.ask` up to `2 * n_points` times
         """
-
-        def unique_list(a_list):
-            new_list = []
-            for item in a_list:
-                if item not in new_list:
-                    new_list.append(item)
-            return new_list
-
+        asked_non_tried: list[FrozenTrial] = []
+        optuna_asked_trials = self.get_optuna_asked_points(n_points=n_points, dimensions=dimensions)
+        asked_non_tried += [
+            x
+            for x in optuna_asked_trials
+            if not self.duplicate_optuna_asked_points(x, optuna_asked_trials)
+        ]
         i = 0
-        asked_non_tried: list[list[Any]] = []
-        is_random_non_tried: list[bool] = []
-        while i < 5 and len(asked_non_tried) < n_points:
-            if i < 3:
-                self.opt.cache_ = {}
-                asked = unique_list(
-                    self.get_optuna_asked_points(
-                        n_points=n_points * 5 if i > 0 else n_points, dimensions=dimensions
-                    )
-                )
-                is_random = [False for _ in range(len(asked))]
-            else:
-                asked = unique_list(self.opt.space.rvs(n_samples=n_points * 5))
-                is_random = [True for _ in range(len(asked))]
-            is_random_non_tried += [
-                rand for x, rand in zip(asked, is_random, strict=False) if x not in asked_non_tried
-            ]
-            asked_non_tried += [x for x in asked if x not in asked_non_tried]
+        while i < 2 * n_points and len(asked_non_tried) < n_points:
+            asked_new = self.get_optuna_asked_points(n_points=1, dimensions=dimensions)[0]
+            if not self.duplicate_optuna_asked_points(asked_new, asked_non_tried):
+                asked_non_tried.append(asked_new)
             i += 1
+        if len(asked_non_tried) < n_points:
+            if self.count_skipped_epochs == 0:
+                logger.warning("Duplicate params detected. Maybe your search space is too small?")
+            self.count_skipped_epochs += n_points - len(asked_non_tried)
 
-        if asked_non_tried:
-            return (
-                asked_non_tried[: min(len(asked_non_tried), n_points)],
-                is_random_non_tried[: min(len(asked_non_tried), n_points)],
-            )
-        else:
-            return self.get_optuna_asked_points(n_points=n_points, dimensions=dimensions), [
-                False for _ in range(n_points)
-            ]
+        return asked_non_tried, [False for _ in range(len(asked_non_tried))]
 
     def evaluate_result(self, val: dict[str, Any], current: int, is_random: bool):
         """
```
```diff
@@ -304,6 +296,7 @@ def start(self) -> None:
                         parallel,
                         [asked1.params for asked1 in asked],
                     )
+
                     f_val_loss = [v["loss"] for v in f_val]
                     for o_ask, v in zip(asked, f_val_loss, strict=False):
                         self.opt.tell(o_ask, v)
```
```diff
@@ -327,6 +320,12 @@ def start(self) -> None:
         except KeyboardInterrupt:
             print("User interrupted..")
 
+        if self.count_skipped_epochs > 0:
+            logger.info(
+                f"{self.count_skipped_epochs} {plural(self.count_skipped_epochs, 'epoch')} "
+                f"skipped due to duplicate parameters."
+            )
+
         logger.info(
             f"{self.num_epochs_saved} {plural(self.num_epochs_saved, 'epoch')} "
             f"saved to '{self.results_file}'."
```

freqtrade/optimize/hyperopt/hyperopt_optimizer.py

Lines changed: 11 additions & 1 deletion
```diff
@@ -45,6 +45,7 @@
 
 logger = logging.getLogger(__name__)
 
+INITIAL_POINTS = 30
 
 MAX_LOSS = 100000  # just a big enough number to be bad result in loss optimization
 
```
```diff
@@ -425,7 +426,16 @@ def get_optimizer(
             raise OperationalException(f"Optuna Sampler {o_sampler} not supported.")
         with warnings.catch_warnings():
             warnings.filterwarnings(action="ignore", category=ExperimentalWarning)
-            sampler = optuna_samplers_dict[o_sampler](seed=random_state)
+            if o_sampler in ["NSGAIIISampler", "NSGAIISampler"]:
+                sampler = optuna_samplers_dict[o_sampler](
+                    seed=random_state, population_size=INITIAL_POINTS
+                )
+            elif o_sampler in ["GPSampler", "TPESampler", "CmaEsSampler"]:
+                sampler = optuna_samplers_dict[o_sampler](
+                    seed=random_state, n_startup_trials=INITIAL_POINTS
+                )
+            else:
+                sampler = optuna_samplers_dict[o_sampler](seed=random_state)
     else:
         sampler = o_sampler
 
```
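The two keyword arguments map directly onto Optuna's sampler constructors: `n_startup_trials` sets how many random startup trials TPE/GP/CMA-ES run before their model takes over, while `population_size` plays the analogous role for the NSGA-II/III genetic samplers. A minimal sketch of the resulting configuration (values as in the commit):

```python
import optuna

INITIAL_POINTS = 30  # mirrors the constant now defined in hyperopt_optimizer.py

tpe = optuna.samplers.TPESampler(seed=42, n_startup_trials=INITIAL_POINTS)
nsga2 = optuna.samplers.NSGAIISampler(seed=42, population_size=INITIAL_POINTS)

study = optuna.create_study(sampler=tpe, direction="minimize")
```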
