Skip to content

Commit 397850c

Browse files
committed
add previous runs to *.json when warm starting with result_logger
1 parent 841db4b commit 397850c

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

hpbandster/core/base_iteration.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ class WarmStartIteration(BaseIteration):
254254
iteration that imports a privious Result for warm starting
255255
"""
256256

257-
def __init__(self, Result, config_generator):
257+
def __init__(self, Result, config_generator, result_logger=None):
258258

259259
self.is_finished=False
260260
self.stage = 0
@@ -263,11 +263,14 @@ def __init__(self, Result, config_generator):
263263
id2conf = Result.get_id2config_mapping()
264264
delta_t = - max(map(lambda r: r.time_stamps['finished'], Result.get_all_runs()))
265265

266-
super().__init__(-1, [len(id2conf)] , [None], None)
266+
super().__init__(-1, [len(id2conf)], [None],
267+
None,
268+
result_logger=result_logger)
267269

268270

269271
for i, id in enumerate(id2conf):
270272
new_id = self.add_configuration(config=id2conf[id]['config'], config_info=id2conf[id]['config_info'])
273+
# if result_logger exists, add this config to configs.json
271274

272275
for r in Result.get_runs_by_id(id):
273276

@@ -281,6 +284,9 @@ def __init__(self, Result, config_generator):
281284
j.timestamps[k] = v + delta_t
282285

283286
self.register_result(j , skip_sanity_checks=True)
287+
288+
if self.result_logger:
289+
self.result_logger(j) # add prev jobs to results.json
284290

285291
config_generator.new_result(j, update_model=(i==len(id2conf)-1))
286292

hpbandster/core/master.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,11 @@ def __init__(self,
111111
self.warmstart_iteration = []
112112

113113
else:
114-
self.warmstart_iteration = [WarmStartIteration(previous_result, self.config_generator)]
114+
self.warmstart_iteration = [
115+
WarmStartIteration(previous_result,
116+
self.config_generator,
117+
result_logger=self.result_logger)
118+
]
115119

116120
# condition to synchronize the job_callback and the queue
117121
self.thread_cond = threading.Condition()

0 commit comments

Comments
 (0)