@@ -181,28 +181,35 @@ def trained_rcomp(system, tr, Utr, res_ode=None, **opt_params):
181181 rcomp (ResComp): Trained reservoir computer
182182 """
183183 resprms , methodprms , otherprms = build_params (system , opt_params )
184- if system .is_driven :
185- rcomp = rc .DrivenResComp (** resprms )
186- else :
187- rcomp = rc .ResComp (** resprms )
188184
189- if res_ode is not None :
190- #ResComp and DrivenResComp call the ODE functions different things
185+ if 'ResComp' in opt_params .keys ():
186+ #Use provided rescomp constructor if given
187+ #Should handle all parameters
188+ otherprms .pop ('ResComp' )
189+ rcomp = opt_params ['ResComp' ](** resprms , ** otherprms )
190+ else :
191191 if system .is_driven :
192- bind_function (rcomp , res_ode ['res_ode' ], 'res_f' )
193- bind_function (rcomp , res_ode ['trained_res_ode' ], 'res_pred_f' )
192+ rcomp = rc .DrivenResComp (** resprms )
194193 else :
195- bind_function (rcomp , res_ode ['res_ode' ], 'res_ode' )
196- bind_function (rcomp , res_ode ['trained_res_ode' ], 'trained_res_ode' )
197- initial = res_ode .get ('initial_condition' )
198- if initial is not None :
199- bind_function (rcomp , initial , 'initial_condition' )
200-
201- for var in otherprms .keys ():
202- if callable (otherprms [var ]):
203- bind_function (rcomp , otherprms [var ], var )
204- else :
205- setattr (rcomp , var , otherprms [var ])
194+ rcomp = rc .ResComp (** resprms )
195+
196+ if res_ode is not None :
197+ #ResComp and DrivenResComp call the ODE functions different things
198+ if system .is_driven :
199+ bind_function (rcomp , res_ode ['res_ode' ], 'res_f' )
200+ bind_function (rcomp , res_ode ['trained_res_ode' ], 'res_pred_f' )
201+ else :
202+ bind_function (rcomp , res_ode ['res_ode' ], 'res_ode' )
203+ bind_function (rcomp , res_ode ['trained_res_ode' ], 'trained_res_ode' )
204+ initial = res_ode .get ('initial_condition' )
205+ if initial is not None :
206+ bind_function (rcomp , initial , 'initial_condition' )
207+
208+ for var in otherprms .keys ():
209+ if callable (otherprms [var ]):
210+ bind_function (rcomp , otherprms [var ], var )
211+ else :
212+ setattr (rcomp , var , otherprms [var ])
206213
207214 if system .is_driven :
208215 rcomp .train (tr , * Utr , ** methodprms )
0 commit comments