Skip to content

Commit 8096320

Browse files
Add option on some functions to return used ResComp object
1 parent b0e146b commit 8096320

File tree

2 files changed

+9
-6
lines changed

2 files changed

+9
-6
lines changed

rescomp/optimizer/optimizer_controller.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def __init__(self, system, map_initial, prediction_type, method, res_ode=None,
3434
parallel=False, parallel_profile=None, **res_params):
3535
"""
3636
Arguments:
37-
system (string or template.System): the system to use. If not a template.System object, will attempt to load one using rescomp.optimizer.get_system(system)
37+
system (string or rescomp.optimizer.System): the system to use. If not a rescomp.optimizer.System object, will attempt to load one using rescomp.optimizer.get_system(system)
3838
map_initial (string): initial condition mapping for reservoir computer to use
3939
prediction_type (string): 'random' or 'continue'; prediction type to use while optimizing.
4040
method (string): training method; 'standard' or 'augmented'
@@ -141,7 +141,7 @@ def run_tests(self, test_ntrials, lyap_reps=20, parameters=None):
141141
results_dict["rand_deriv_fit"].append(rand_df)
142142
return results_dict
143143

144-
def generate_orbits(self, n_orbits, parameters=None):
144+
def generate_orbits(self, n_orbits, parameters=None, return_rescomp=False):
145145
"""
146146
Trains a reservoir computer and has it predict, using the given hyperparameters
147147
If parameters are not specified, uses the optimized hyperparameters.
@@ -158,12 +158,12 @@ def generate_orbits(self, n_orbits, parameters=None):
158158

159159
if self.parallel:
160160
results, _ = self._run_n_times_parallel(n_orbits, create_orbit,
161-
self.system, self.prediction_type, **self.res_params, **parameters)
161+
self.system, self.prediction_type, **self.res_params, **parameters, return_rescomp=return_rescomp)
162162
#Collapse into single list of outputs
163163
results = [item for sublist in results for item in sublist]
164164
else:
165165
results = self._run_n_times(n_orbits, create_orbit,
166-
self.system, self.prediction_type, **self.res_params, **parameters)
166+
self.system, self.prediction_type, **self.res_params, **parameters, return_rescomp=return_rescomp)
167167

168168
return results
169169

rescomp/optimizer/optimizer_functions.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def meanlyap(system, rcomp, pre, r0, ts, pert_size=1e-6, lyap_reps=20):
120120
lam += rc.lyapunov(ts[:i], pre[:i, :], predelta[:i, :], delta0)
121121
return lam / lyap_reps
122122

123-
def create_orbit(*args, **kwargs):
123+
def create_orbit(*args, return_rescomp=False, **kwargs):
124124
"""
125125
Trains a reservoir computer and has it predict, using the given arguments
126126
@@ -148,7 +148,10 @@ def create_orbit(*args, **kwargs):
148148
init_cond = make_initial(pred_type, rcomp, Uts)
149149
pre = rcomp_prediction(system, rcomp, ts, init_cond)
150150

151-
return tr, Utr, ts, Uts, pre
151+
if return_rescomp:
152+
return rcomp, tr, Utr, ts, Uts, pre
153+
else:
154+
return tr, Utr, ts, Uts, pre
152155

153156
#########################
154157
## Rescomp creation

0 commit comments

Comments
 (0)