Skip to content

Commit 600da47

Browse files
Add way to use custom ResComp class
1 parent 8096320 commit 600da47

File tree

1 file changed

+26
-19
lines changed

1 file changed

+26
-19
lines changed

rescomp/optimizer/optimizer_functions.py

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -181,28 +181,35 @@ def trained_rcomp(system, tr, Utr, res_ode=None, **opt_params):
181181
rcomp (ResComp): Trained reservoir computer
182182
"""
183183
resprms, methodprms, otherprms = build_params(system, opt_params)
184-
if system.is_driven:
185-
rcomp = rc.DrivenResComp(**resprms)
186-
else:
187-
rcomp = rc.ResComp(**resprms)
188184

189-
if res_ode is not None:
190-
#ResComp and DrivenResComp call the ODE functions different things
185+
if 'ResComp' in opt_params.keys():
186+
#Use provided rescomp constructor if given
187+
#Should handle all parameters
188+
otherprms.pop('ResComp')
189+
rcomp = opt_params['ResComp'](**resprms, **otherprms)
190+
else:
191191
if system.is_driven:
192-
bind_function(rcomp, res_ode['res_ode'], 'res_f')
193-
bind_function(rcomp, res_ode['trained_res_ode'], 'res_pred_f')
192+
rcomp = rc.DrivenResComp(**resprms)
194193
else:
195-
bind_function(rcomp, res_ode['res_ode'], 'res_ode')
196-
bind_function(rcomp, res_ode['trained_res_ode'], 'trained_res_ode')
197-
initial = res_ode.get('initial_condition')
198-
if initial is not None:
199-
bind_function(rcomp, initial, 'initial_condition')
200-
201-
for var in otherprms.keys():
202-
if callable(otherprms[var]):
203-
bind_function(rcomp, otherprms[var], var)
204-
else:
205-
setattr(rcomp, var, otherprms[var])
194+
rcomp = rc.ResComp(**resprms)
195+
196+
if res_ode is not None:
197+
#ResComp and DrivenResComp call the ODE functions different things
198+
if system.is_driven:
199+
bind_function(rcomp, res_ode['res_ode'], 'res_f')
200+
bind_function(rcomp, res_ode['trained_res_ode'], 'res_pred_f')
201+
else:
202+
bind_function(rcomp, res_ode['res_ode'], 'res_ode')
203+
bind_function(rcomp, res_ode['trained_res_ode'], 'trained_res_ode')
204+
initial = res_ode.get('initial_condition')
205+
if initial is not None:
206+
bind_function(rcomp, initial, 'initial_condition')
207+
208+
for var in otherprms.keys():
209+
if callable(otherprms[var]):
210+
bind_function(rcomp, otherprms[var], var)
211+
else:
212+
setattr(rcomp, var, otherprms[var])
206213

207214
if system.is_driven:
208215
rcomp.train(tr, *Utr, **methodprms)

0 commit comments

Comments
 (0)