22
33import gc
44import itertools as it
5+ import json
56import logging
67from copy import deepcopy
78from functools import partial
@@ -106,16 +107,13 @@ def fit(self, context: Context, sampler: SamplerType = "brute", n_jobs: int = 1)
106107
107108 study , finished_trials , n_trials = load_or_create_study (
108109 study_name = f"{ self .node_info .node_type } _{ module_name } " ,
109- storage_dir = context . get_dump_dir () ,
110+ context = context ,
110111 direction = "maximize" ,
111112 sampler = sampler_instance ,
112113 n_trials = n_trials ,
113114 )
114115 self ._counter = max (self ._counter , finished_trials )
115116
116- if n_trials == 0 :
117- context .load ()
118-
119117 optuna .logging .set_verbosity (optuna .logging .WARNING )
120118 obj = partial (self .objective , module_name = module_name , search_space = search_space , context = context )
121119
@@ -143,7 +141,7 @@ def objective(
143141 """
144142 config = self .suggest (trial , search_space )
145143
146- self ._logger .debug ("Initializing %s module... " , module_name )
144+ self ._logger .debug ("Initializing %s module with config: %s " , module_name , json . dumps ( config ) )
147145 module = self .node_info .modules_available [module_name ].from_context (context , ** config )
148146
149147 embedder_config = module .get_embedder_config ()
@@ -338,7 +336,7 @@ def get_storage_url(study_name: str, storage_dir: Path | None) -> str | None:
338336
339337def load_or_create_study (
340338 study_name : str ,
341- storage_dir : Path | None ,
339+ context : Context ,
342340 sampler : optuna .samplers .BaseSampler ,
343341 direction : str = "maximize" ,
344342 n_trials : int = 10 ,
@@ -347,7 +345,7 @@ def load_or_create_study(
347345
348346 Args:
349347 study_name: Name of the study
350- storage_dir: Directory where study databases are stored
348+ context: Context object
351349 direction: Optimization direction (maximize or minimize)
352350 sampler: Optuna sampler instance
353351 n_trials: n_trials
@@ -358,7 +356,7 @@ def load_or_create_study(
358356 remaining_trials = n_trials
359357 finished_trials = 0
360358
361- storage_url = get_storage_url (study_name , storage_dir )
359+ storage_url = get_storage_url (study_name , context . get_dump_dir () )
362360
363361 try :
364362 # will catch exception if study does not exist
@@ -373,6 +371,8 @@ def load_or_create_study(
373371 finished_trials = max (t .number for t in study .trials ) + 1
374372 # Calculate remaining trials if n_trials is specified
375373 remaining_trials = n_trials if n_trials is None else max (0 , n_trials - len (study .trials ))
374+ if remaining_trials == 0 :
375+ context .load ()
376376 return study , finished_trials , remaining_trials # noqa: TRY300
377377 except Exception : # noqa: BLE001
378378 # Create a new study if none exists
0 commit comments