Skip to content

Commit f29b52a

Browse files
committed
try to fix
1 parent b57c002 commit f29b52a

File tree

2 files changed

+15
-10
lines changed

2 files changed

+15
-10
lines changed

autointent/context/_context.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,14 +88,19 @@ def dump(self) -> None:
8888
yaml.dump(inference_config, file)
8989

9090
def load(self) -> None:
91-
"""Load all information about optimization process from disk."""
91+
"""Restore the context state to resume the optimization process.
92+
93+
Raises:
94+
RuntimeError: If the modules artifacts are not found.
95+
"""
9296
self._logger.debug("loading logs...")
9397
logs_dir = self.logging_config.dirpath
9498
self.optimization_info.load(logs_dir)
9599
if not self.optimization_info.artifacts.has_artifacts():
96100
msg = (
97101
"It is impossible to continue from the previous point, "
98-
"start again with dump_modules=True settings if you want to resume the run"
102+
"start again with dump_modules=True settings if you want to resume the run."
103+
"To load optimization info only, use Context.optimization_info.load(logs_dir)."
99104
)
100105
raise RuntimeError(msg)
101106

autointent/nodes/_node_optimizer.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import gc
44
import itertools as it
5+
import json
56
import logging
67
from copy import deepcopy
78
from functools import partial
@@ -106,16 +107,13 @@ def fit(self, context: Context, sampler: SamplerType = "brute", n_jobs: int = 1)
106107

107108
study, finished_trials, n_trials = load_or_create_study(
108109
study_name=f"{self.node_info.node_type}_{module_name}",
109-
storage_dir=context.get_dump_dir(),
110+
context=context,
110111
direction="maximize",
111112
sampler=sampler_instance,
112113
n_trials=n_trials,
113114
)
114115
self._counter = max(self._counter, finished_trials)
115116

116-
if n_trials == 0:
117-
context.load()
118-
119117
optuna.logging.set_verbosity(optuna.logging.WARNING)
120118
obj = partial(self.objective, module_name=module_name, search_space=search_space, context=context)
121119

@@ -143,7 +141,7 @@ def objective(
143141
"""
144142
config = self.suggest(trial, search_space)
145143

146-
self._logger.debug("Initializing %s module...", module_name)
144+
self._logger.debug("Initializing %s module with config: %s", module_name, json.dumps(config))
147145
module = self.node_info.modules_available[module_name].from_context(context, **config)
148146

149147
embedder_config = module.get_embedder_config()
@@ -338,7 +336,7 @@ def get_storage_url(study_name: str, storage_dir: Path | None) -> str | None:
338336

339337
def load_or_create_study(
340338
study_name: str,
341-
storage_dir: Path | None,
339+
context: Context,
342340
sampler: optuna.samplers.BaseSampler,
343341
direction: str = "maximize",
344342
n_trials: int = 10,
@@ -347,7 +345,7 @@ def load_or_create_study(
347345
348346
Args:
349347
study_name: Name of the study
350-
storage_dir: Directory where study databases are stored
348+
context: Context object
351349
direction: Optimization direction (maximize or minimize)
352350
sampler: Optuna sampler instance
353351
n_trials: n_trials
@@ -358,7 +356,7 @@ def load_or_create_study(
358356
remaining_trials = n_trials
359357
finished_trials = 0
360358

361-
storage_url = get_storage_url(study_name, storage_dir)
359+
storage_url = get_storage_url(study_name, context.get_dump_dir())
362360

363361
try:
364362
# will catch exception if study does not exist
@@ -373,6 +371,8 @@ def load_or_create_study(
373371
finished_trials = max(t.number for t in study.trials) + 1
374372
# Calculate remaining trials if n_trials is specified
375373
remaining_trials = n_trials if n_trials is None else max(0, n_trials - len(study.trials))
374+
if remaining_trials == 0:
375+
context.load()
376376
return study, finished_trials, remaining_trials # noqa: TRY300
377377
except Exception: # noqa: BLE001
378378
# Create a new study if none exists

0 commit comments

Comments
 (0)