1414from jax .lax import scan as _scan
1515import time , sys
1616
17- def get_integrator_code (integrationType ):
17+ def get_integrator_code (integrationType ): ## integrator type decoding routine
1818 """
1919 Convenience function for mapping integrator type string to ngc-learn's
2020 internal integer code value.
@@ -42,22 +42,20 @@ def get_integrator_code(integrationType):
4242 to RK-1/Euler routine" .format (integrationType ))
4343 return intgFlag
4444
45-
4645@jit
47- def _sum_combine (* args , ** kwargs ): ## fast co-routine for simple addition
48- sum = 0
49-
46+ def _sum_combine (* args , ** kwargs ): ## fast co-routine for simple addition/summation
47+ _sum = 0
5048 for arg , val in zip (args , kwargs .values ()):
51- sum = sum + val * arg
52- return sum
49+ _sum = _sum + val * arg
50+ return _sum
5351
5452@jit
5553def _step_forward (t , x , dx_dt , dt , x_scale ): ## internal step co-routine
5654 _t = t + dt
5755 _x = x * x_scale + dx_dt * dt
5856 return _t , _x
5957
60- @partial (jit , static_argnums = (2 , 3 , 4 , 5 , ) )
58+ @partial (jit , static_argnums = (2 , 3 , 5 )) #(2, 3, 4, 5 )
6159def step_euler (t , x , dfx , dt , params , x_scale = 1. ):
6260 """
6361 Iteratively integrates one step forward via the Euler method, i.e., a
@@ -81,14 +79,12 @@ def step_euler(t, x, dfx, dt, params, x_scale=1.):
8179 Returns:
8280 variable values iteratively integrated/advanced to next step (`t + dt`)
8381 """
84-
8582 carry = (t , x )
8683 next_state , * _ = _euler (carry , dfx , dt , params , x_scale = x_scale )
8784 _t , _x = next_state
88-
8985 return _t , _x
9086
91- @partial (jit , static_argnums = (1 , 2 , 3 , 4 , ))
87+ @partial (jit , static_argnums = (1 , 2 , 4 )) #(1, 2, 3, 4 ))
9288def _euler (carry , dfx , dt , params , x_scale = 1. ):
9389 """
9490 Iteratively integrates one step forward via the Euler method, i.e., a
@@ -111,17 +107,12 @@ def _euler(carry, dfx, dt, params, x_scale=1.):
111107 variable values iteratively integrated/advanced to next step (`t + dt`)
112108 """
113109 t , x = carry
114-
115110 dx_dt = dfx (t , x , params )
116111 _t , _x = _step_forward (t , x , dx_dt , dt , x_scale )
117-
118112 new_carry = (_t , _x )
119113 return new_carry , (new_carry , carry )
120114
121-
122-
123-
124-
115+ @partial (jit , static_argnums = (2 , 3 , 5 )) #(2, 3, 4, 5))
125116def step_heun (t , x , dfx , dt , params , x_scale = 1. ):
126117 """
127118 Iteratively integrates one step forward via Heun's method, i.e., a
@@ -155,23 +146,11 @@ def step_heun(t, x, dfx, dt, params, x_scale=1.):
155146 """
156147
157148 carry = (t , x )
158-
159149 next_state , * _ = _heun (carry , dfx , dt , params , x_scale = x_scale )
160-
161- #
162- # dx_dt = dfx(t, x, params)
163- #
164- # _t, _x = _step_forward(t, x, dx_dt, dt, x_scale)
165- # _dx_dt = dfx(_t, _x, params)
166- # summed_dx_dt = _sum_combine(dx_dt, _dx_dt, weight1=1, weight2=1)
167-
168- # _, _x = _step_forward(t, x, summed_dx_dt, dt * 0.5, x_scale)
169150 _t , _x = next_state
170-
171151 return _t , _x
172152
173-
174- @partial (jit , static_argnums = (1 , 2 , 3 , 4 , ))
153+ @partial (jit , static_argnums = (1 , 2 , 4 )) #(1, 2, 3, 4, ))
175154def _heun (carry , dfx , dt , params , x_scale = 1. ):
176155 """
177156 Iteratively integrates one step forward via Heun's method, i.e., a
@@ -202,19 +181,15 @@ def _heun(carry, dfx, dt, params, x_scale=1.):
202181 variable values iteratively integrated/advanced to next step (`t + dt`)
203182 """
204183 t , x = carry
205-
206184 dx_dt = dfx (t , x , params )
207185 _t , _x = _step_forward (t , x , dx_dt , dt , x_scale )
208186 _dx_dt = dfx (_t , _x , params )
209187 summed_dx_dt = _sum_combine (dx_dt , _dx_dt , weight1 = 1 , weight2 = 1 )
210188 _ , _x = _step_forward (t , x , summed_dx_dt , dt * 0.5 , x_scale )
211-
212189 new_carry = (_t , _x )
213190 return new_carry , (new_carry , carry )
214191
215-
216-
217-
192+ @partial (jit , static_argnums = (2 , 3 , 5 )) #(2, 3, 4, 5))
218193def step_rk2 (t , x , dfx , dt , params , x_scale = 1. ):
219194 """
220195 Iteratively integrates one step forward via the midpoint method, i.e., a
@@ -244,30 +219,12 @@ def step_rk2(t, x, dfx, dt, params, x_scale=1.):
244219 Returns:
245220 variable values iteratively integrated/advanced to next step (`t + dt`)
246221 """
247-
248222 carry = (t , x )
249223 next_state , * _ = _rk2 (carry , dfx , dt , params , x_scale = x_scale )
250224 _t , _x = next_state
251-
252- #
253- # dx_dt = dfx(t, x, params)
254- #
255- # _t, _x = _step_forward(t, x, dx_dt, dt, x_scale)
256- # _dx_dt = dfx(_t, _x, params)
257- # summed_dx_dt = _sum_combine(dx_dt, _dx_dt, weight1=1, weight2=1)
258-
259- # _, _x = _step_forward(t, x, summed_dx_dt, dt * 0.5, x_scale)
260-
261-
262- # dfx_1 = dfx(t, x, params)
263- #
264- # t1, x1 = _step_forward(t, x, dfx_1, dt * 0.5, x_scale)
265- # dfx_2 = dfx(t1, x1, params)
266- # _t, _x = _step_forward(t, x, dfx_2, dt, x_scale)
267225 return _t , _x
268226
269-
270- @partial (jit , static_argnums = (1 , 2 , 3 , 4 , ))
227+ @partial (jit , static_argnums = (1 , 2 , 4 )) #(1, 2, 3, 4, ))
271228def _rk2 (carry , dfx , dt , params , x_scale = 1. ):
272229 """
273230 Iteratively integrates one step forward via the midpoint method, i.e., a
@@ -296,22 +253,19 @@ def _rk2(carry, dfx, dt, params, x_scale=1.):
296253 variable values iteratively integrated/advanced to next step (`t + dt`)
297254 """
298255 t , x = carry
299-
300256 f_1 = dfx (t , x , params )
301257 t1 , x1 = _step_forward (t , x , f_1 , dt * 0.5 , x_scale )
302258 f_2 = dfx (t1 , x1 , params )
303259 _t , _x = _step_forward (t , x , f_2 , dt , x_scale )
304-
305260 new_carry = (_t , _x )
306261 return new_carry , (new_carry , carry )
307262
308-
309-
263+ @partial (jit , static_argnums = (2 , 3 , 5 )) #(2, 3, 4, 5))
310264def step_rk4 (t , x , dfx , dt , params , x_scale = 1. ):
311265 """
312266 Iteratively integrates one step forward via the midpoint method, i.e., a
313267 fourth-order Runge-Kutta (RK-4) step.
314- (Note: ngc-learn internally recognizes "rk4" or this routine)
268+ (Note: ngc-learn internally recognizes "rk4" for this routine)
315269
316270 | Reference:
317271 | Ascher, Uri M., and Linda R. Petzold. Computer methods for ordinary
@@ -339,25 +293,9 @@ def step_rk4(t, x, dfx, dt, params, x_scale=1.):
339293 carry = (t , x )
340294 next_state , * _ = _rk4 (carry , dfx , dt , params , x_scale = x_scale )
341295 _t , _x = next_state
342-
343- # dfx_1 = dfx(t, x, params)
344- # t2, x2 = _step_forward(t, x, dfx_1, dt * 0.5, x_scale)
345- #
346- # dfx_2 = dfx(t2, x2, params)
347- # t3, x3 = _step_forward(t, x, dfx_2, dt * 0.5, x_scale)
348- #
349- # dfx_3 = dfx(t3, x3, params)
350- # t4, x4 = _step_forward(t, x, dfx_3, dt, x_scale)
351- #
352- # dfx_4 = dfx(t4, x4, params)
353- #
354- # _dx_dt = _sum_combine(dfx_1, dfx_2, dfx_3, dfx_4, w_f1=1, w_f2=2, w_f3=2, w_f4=1)
355- # _t, _x = _step_forward(t, x, _dx_dt / 6, dt, x_scale)
356296 return _t , _x
357297
358-
359-
360- @partial (jit , static_argnums = (1 , 2 , 3 , 4 , ))
298+ @partial (jit , static_argnums = (1 , 2 , 4 )) #(1, 2, 3, 4, ))
361299def _rk4 (carry , dfx , dt , params , x_scale = 1. ):
362300 """
363301 Iteratively integrates one step forward via the midpoint method, i.e., a
@@ -385,28 +323,22 @@ def _rk4(carry, dfx, dt, params, x_scale=1.):
385323 Returns:
386324 variable values iteratively integrated/advanced to next step (`t + dt`)
387325 """
388-
389326 t , x = carry
390-
391- dfx_1 = dfx (t , x , params )
327+ ## carry out 4 steps of RK-4
328+ dfx_1 = dfx (t , x , params ) ## k1
392329 t2 , x2 = _step_forward (t , x , dfx_1 , dt * 0.5 , x_scale )
393-
394- dfx_2 = dfx (t2 , x2 , params )
330+ dfx_2 = dfx (t2 , x2 , params ) ## k2
395331 t3 , x3 = _step_forward (t , x , dfx_2 , dt * 0.5 , x_scale )
396-
397- dfx_3 = dfx (t3 , x3 , params )
332+ dfx_3 = dfx (t3 , x3 , params ) ## k3
398333 t4 , x4 = _step_forward (t , x , dfx_3 , dt , x_scale )
399-
400- dfx_4 = dfx (t4 , x4 , params )
401-
334+ dfx_4 = dfx (t4 , x4 , params ) ## k4
335+ ## produce final estimate and move forward
402336 _dx_dt = _sum_combine (dfx_1 , dfx_2 , dfx_3 , dfx_4 , w_f1 = 1 , w_f2 = 2 , w_f3 = 2 , w_f4 = 1 )
403337 _t , _x = _step_forward (t , x , _dx_dt / 6 , dt , x_scale )
404-
405338 new_carry = (_t , _x )
406339 return new_carry , (new_carry , carry )
407340
408-
409-
341+ @partial (jit , static_argnums = (2 , 3 , 5 )) #(2, 3, 4, 5))
410342def step_ralston (t , x , dfx , dt , params , x_scale = 1. ):
411343 """
412344 Iteratively integrates one step forward via Ralston's method, i.e., a
@@ -438,22 +370,12 @@ def step_ralston(t, x, dfx, dt, params, x_scale=1.):
438370 Returns:
439371 variable values iteratively integrated/advanced to next step (`t + dt`)
440372 """
441-
442373 carry = (t , x )
443- next_state , * _ = _rk4 (carry , dfx , dt , params , x_scale = x_scale )
374+ next_state , * _ = _ralston (carry , dfx , dt , params , x_scale = x_scale )
444375 _t , _x = next_state
445-
446- # dx_dt = dfx(t, x, params) ## k1
447- # tm, xm = _step_forward(t, x, dx_dt, dt * 0.75, x_scale)
448- # _dx_dt = dfx(tm, xm, params) ## k2
449- # ## Note: new step is a weighted combination of k1 and k2
450- # summed_dx_dt = _sum_combine(dx_dt, _dx_dt, weight1=(1./3.), weight2=(2./3.))
451- # _t, _x = _step_forward(t, x, summed_dx_dt, dt, x_scale)
452376 return _t , _x
453377
454-
455-
456- @partial (jit , static_argnums = (1 , 2 , 3 , 4 ,))
378+ @partial (jit , static_argnums = (1 , 2 , 4 )) #(1, 2, 3, 4,))
457379def _ralston (carry , dfx , dt , params , x_scale = 1. ):
458380 """
459381 Iteratively integrates one step forward via Ralston's method, i.e., a
@@ -485,22 +407,18 @@ def _ralston(carry, dfx, dt, params, x_scale=1.):
485407 """
486408
487409 t , x = carry
488-
489410 dx_dt = dfx (t , x , params ) ## k1
490411 tm , xm = _step_forward (t , x , dx_dt , dt * 0.75 , x_scale )
491412 _dx_dt = dfx (tm , xm , params ) ## k2
492413 ## Note: new step is a weighted combination of k1 and k2
493414 summed_dx_dt = _sum_combine (dx_dt , _dx_dt , weight1 = (1. / 3. ), weight2 = (2. / 3. ))
494415 _t , _x = _step_forward (t , x , summed_dx_dt , dt , x_scale )
495-
496416 new_carry = (_t , _x )
497417 return new_carry , (new_carry , carry )
498418
499419
500-
501420@partial (jit , static_argnums = (0 , 3 , 4 , 5 , 6 , 7 , 8 ))
502421def solve_ode (method_name , t0 , x0 , T , dfx , dt , params = None , x_scale = 1. , sols_only = True ):
503-
504422 if method_name == 'euler' :
505423 method = _euler
506424 elif method_name == 'heun' :
0 commit comments