Skip to content

Commit 35d77f9

Browse files
authored
Merge pull request #15 from djpasseyjr/increments
Optimizer - add support for saving incremental results to a specific file
2 parents ea27856 + 1ade003 commit 35d77f9

File tree

1 file changed

+36
-3
lines changed

1 file changed

+36
-3
lines changed

rescomp/optimizer/optimizer_controller.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from ipyparallel import Client
55
import sherpa
66
import numpy as np
7+
import os
8+
import dill as pickle
79

810
from .optimizer_systems import get_system, loadprior
911
from .templates import System
@@ -30,7 +32,7 @@ class ResCompOptimizer:
3032
node_count (int): number of nodes in use
3133
"""
3234
def __init__(self, system, map_initial, prediction_type, method, res_ode=None,
33-
add_params=None, rm_params=None, results_directory="", data_directory="",
35+
add_params=None, rm_params=None, results_directory=None, data_directory="", progress_file=None,
3436
parallel=False, parallel_profile=None, **res_params):
3537
"""
3638
Arguments:
@@ -44,8 +46,10 @@ def __init__(self, system, map_initial, prediction_type, method, res_ode=None,
4446
add_params: list of sherpa.Parameter objects to include in optimization
4547
rm_params (list of str): names of optimization parameters to remove
4648
47-
results_directory (str): pathname of where to store optimization results. Default will store in current directory.
49+
results_directory (str or None): pathname of where to store sherpa's optimization results. Default will not save.
4850
data_directory (str): pathname to load additional priors from
51+
progress_file (str or None): file to save optimization progress to. If the file exists, it will be loaded from. More
52+
flexible than results_directory.
4953
5054
parallel (bool): whether to use parallelization. Default false
5155
parallel_profile (str or None): when using parallelization, the ipyparallel profile to connect to.
@@ -60,6 +64,17 @@ def __init__(self, system, map_initial, prediction_type, method, res_ode=None,
6064
self.system = system
6165
self.prediction_type = prediction_type
6266

67+
# Initialize observations
68+
self.progress_file = progress_file
69+
if progress_file is not None and os.path.exists(progress_file):
70+
try:
71+
with open(progress_file, 'rb') as file:
72+
self.opt_observations = pickle.load(file)
73+
except Exception as e:
74+
self.opt_observations = []
75+
else:
76+
self.opt_observations = []
77+
6378
self.parallel = parallel
6479
self.results_directory = results_directory
6580

@@ -84,19 +99,37 @@ def run_optimization(self, opt_ntrials, vpt_reps, algorithm='gpyopt', max_stderr
8499
sherpa_dashboard (bool): whether to use the sherpa dashboard. Default false.
85100
raise_err (bool): whether errors occuring during VPT calculations should be raised; otherwise are suppressed and a VPT of -1 is reported"""
86101
self._initialize_sherpa(opt_ntrials, algorithm, sherpa_dashboard)
102+
# Include pre-observed trials
103+
N = len(self.opt_observations)
104+
for i,obs in enumerate(self.opt_observations):
105+
trial = sherpa.core.Trial(i-N, obs[0])
106+
self.study.add_trial(trial)
107+
self.study.add_observation(trial=trial,
108+
objective=obs[1],
109+
context={'vpt_stderr':obs[2]})
110+
self.study.finalize(trial)
111+
87112
for trial in self.study:
113+
if trial.id < 0:
114+
continue
88115
try:
89116
exp_vpt, stderr = self.run_single_vpt_test(vpt_reps, trial.parameters, raise_err=raise_err)
90117
except Exception as e:
91118
#print relevant information for debugging
92119
print("Trial parameters at error:", trial.parameters)
93120
print("Other parameters:", self.res_params)
94121
raise e
122+
self.opt_observations.append((trial.parameters, exp_vpt, stderr))
95123
self.study.add_observation(trial=trial,
96124
objective=exp_vpt,
97125
context={'vpt_stderr':stderr})
98126
self.study.finalize(trial)
99-
self.study.save(self.results_directory)
127+
if self.results_directory is not None:
128+
self.study.save(self.results_directory)
129+
# Save in a way that's more flexible
130+
if self.progress_file is not None:
131+
with open(self.progress_file, 'wb') as file:
132+
pickle.dump(self.opt_observations, file)
100133

101134
def run_tests(self, test_ntrials, lyap_reps=20, parameters=None):
102135
"""

0 commit comments

Comments
 (0)