@@ -41,20 +41,12 @@ def __init__(self, weights, nodes, matrix):
4141
4242 self .tleft = 0.0
4343 self .tright = 1.0
44- self .num_solution_stages = 0 if self .globally_stiffly_accurate else 1
45- self .num_nodes = matrix .shape [0 ] + self .num_solution_stages
44+ self .num_nodes = matrix .shape [0 ]
4645 self .weights = weights
4746
48- if self .globally_stiffly_accurate :
49- # For globally stiffly accurate methods, the last row of the Butcher tableau is the same as the weights.
50- self .nodes = np .append ([0 ], nodes )
51- self .Qmat = np .zeros ([self .num_nodes + 1 , self .num_nodes + 1 ])
52- self .Qmat [1 :, 1 :] = matrix
53- else :
54- self .nodes = np .append (np .append ([0 ], nodes ), [1 ])
55- self .Qmat = np .zeros ([self .num_nodes + 1 , self .num_nodes + 1 ])
56- self .Qmat [1 :- 1 , 1 :- 1 ] = matrix
57- self .Qmat [- 1 , 1 :- 1 ] = weights # this is for computing the solution to the step from the previous stages
47+ self .nodes = np .append ([0 ], nodes )
48+ self .Qmat = np .zeros ([self .num_nodes + 1 , self .num_nodes + 1 ])
49+ self .Qmat [1 :, 1 :] = matrix
5850
5951 self .left_is_node = True
6052 self .right_is_node = self .nodes [- 1 ] == self .tright
@@ -67,7 +59,7 @@ def __init__(self, weights, nodes, matrix):
6759 self .delta_m [0 ] = self .nodes [0 ] - self .tleft
6860
6961 # check if the RK scheme is implicit
70- self .implicit = any (matrix [i , i ] != 0 for i in range (self .num_nodes - self . num_solution_stages ))
62+ self .implicit = any (matrix [i , i ] != 0 for i in range (self .num_nodes ))
7163
7264
7365class ButcherTableauEmbedded (object ):
@@ -103,21 +95,20 @@ def __init__(self, weights, nodes, matrix):
10395 raise ParameterError (f'Incompatible number of nodes! Need { matrix .shape [0 ]} , got { len (nodes )} ' )
10496
10597 # Set number of nodes, left and right interval boundaries
106- self .num_solution_stages = 2
107- self .num_nodes = matrix .shape [0 ] + self .num_solution_stages
98+ self .num_nodes = matrix .shape [0 ]
10899 self .tleft = 0.0
109100 self .tright = 1.0
110101
111102 self .nodes = np .append (np .append ([0 ], nodes ), [1 , 1 ])
112103 self .weights = weights
113104 self .Qmat = np .zeros ([self .num_nodes + 1 , self .num_nodes + 1 ])
114- self .Qmat [1 :- 2 , 1 :- 2 ] = matrix
115- self .Qmat [- 1 , 1 :- 2 ] = weights [0 ] # this is for computing the higher order solution
116- self .Qmat [- 2 , 1 :- 2 ] = weights [1 ] # this is for computing the lower order solution
105+ self .Qmat [1 :, 1 :] = matrix
117106
118107 self .left_is_node = True
119108 self .right_is_node = self .nodes [- 1 ] == self .tright
120109
110+ self .globally_stiffly_accurate = np .allclose (matrix [- 1 ], weights [0 ])
111+
121112 # compute distances between the nodes
122113 if self .num_nodes > 1 :
123114 self .delta_m = self .nodes [1 :] - self .nodes [:- 1 ]
@@ -126,7 +117,7 @@ def __init__(self, weights, nodes, matrix):
126117 self .delta_m [0 ] = self .nodes [0 ] - self .tleft
127118
128119 # check if the RK scheme is implicit
129- self .implicit = any (matrix [i , i ] != 0 for i in range (self .num_nodes - self . num_solution_stages ))
120+ self .implicit = any (matrix [i , i ] != 0 for i in range (self .num_nodes ))
130121
131122
132123class RungeKutta (Sweeper ):
@@ -292,8 +283,7 @@ def update_nodes(self):
292283 lvl .u [m + 1 ][:] = rhs [:]
293284
294285 # update function values (we don't usually need to evaluate the RHS at the solution of the step)
295- if m < M - self .coll .num_solution_stages or self .params .eval_rhs_at_right_boundary :
296- lvl .f [m + 1 ] = prob .eval_f (lvl .u [m + 1 ], lvl .time + lvl .dt * self .coll .nodes [m + 1 ])
286+ lvl .f [m + 1 ] = prob .eval_f (lvl .u [m + 1 ], lvl .time + lvl .dt * self .coll .nodes [m + 1 ])
297287
298288 # indicate presence of new values at this level
299289 lvl .status .updated = True
@@ -304,7 +294,22 @@ def compute_end_point(self):
304294 """
305295 In this Runge-Kutta implementation, the solution to the step is always stored in the last node
306296 """
307- self .level .uend = self .level .u [- 1 ]
297+ if self .coll .globally_stiffly_accurate :
298+ self .level .uend = self .level .u [- 1 ]
299+ if type (self .coll ) == ButcherTableauEmbedded :
300+ self .u_secondary = self .level .u [0 ].copy ()
301+ for w2 , k in zip (self .coll .weights [1 ], self .level .f [1 :]):
302+ self .u_secondary += self .level .dt * w2 * k
303+ else :
304+ self .level .uend = self .level .u [0 ].copy ()
305+ if type (self .coll ) == ButcherTableau :
306+ for w , k in zip (self .coll .weights , self .level .f [1 :]):
307+ self .level .uend += self .level .dt * w * k
308+ elif type (self .coll ) == ButcherTableauEmbedded :
309+ self .u_secondary = self .level .u [0 ].copy ()
310+ for w1 , w2 , k in zip (self .coll .weights [0 ], self .coll .weights [1 ], self .level .f [1 :]):
311+ self .level .uend += self .level .dt * w1 * k
312+ self .u_secondary += self .level .dt * w2 * k
308313
309314 @property
310315 def level (self ):
@@ -356,6 +361,7 @@ class RungeKuttaIMEX(RungeKutta):
356361 """
357362
358363 matrix_explicit = None
364+ weights_explicit = None
359365 ButcherTableauClass_explicit = ButcherTableau
360366
361367 def __init__ (self , params ):
@@ -366,6 +372,7 @@ def __init__(self, params):
366372 params: parameters for the sweeper
367373 """
368374 super ().__init__ (params )
375+ type(self ).weights_explicit = self .weights if self .weights_explicit is None else self .weights_explicit
369376 self .coll_explicit = self .get_Butcher_tableau_explicit ()
370377 self .QE = self .coll_explicit .Qmat
371378
@@ -388,7 +395,7 @@ def predict(self):
388395
389396 @classmethod
390397 def get_Butcher_tableau_explicit (cls ):
391- return cls .ButcherTableauClass_explicit (cls .weights , cls .nodes , cls .matrix_explicit )
398+ return cls .ButcherTableauClass_explicit (cls .weights_explicit , cls .nodes , cls .matrix_explicit )
392399
393400 def integrate (self ):
394401 """
@@ -448,15 +455,41 @@ def update_nodes(self):
448455 else :
449456 lvl .u [m + 1 ][:] = rhs [:]
450457
451- # update function values (we don't usually need to evaluate the RHS at the solution of the step)
452- if m < M - self .coll .num_solution_stages or self .params .eval_rhs_at_right_boundary :
453- lvl .f [m + 1 ] = prob .eval_f (lvl .u [m + 1 ], lvl .time + lvl .dt * self .coll .nodes [m + 1 ])
458+ # update function values
459+ lvl .f [m + 1 ] = prob .eval_f (lvl .u [m + 1 ], lvl .time + lvl .dt * self .coll .nodes [m + 1 ])
454460
455461 # indicate presence of new values at this level
456462 lvl .status .updated = True
457463
458464 return None
459465
466+ def compute_end_point (self ):
467+ """
468+ In this Runge-Kutta implementation, the solution to the step is always stored in the last node
469+ """
470+ if self .coll .globally_stiffly_accurate and self .coll_explicit .globally_stiffly_accurate :
471+ self .level .uend = self .level .u [- 1 ]
472+ if type (self .coll ) == ButcherTableauEmbedded :
473+ self .u_secondary = self .level .u [0 ].copy ()
474+ for w2 , w2E , k in zip (self .coll .weights [1 ], self .coll_explicit .weights [1 ], self .level .f [1 :]):
475+ self .u_secondary += self .level .dt * (w2 * k .impl + w2E * k .expl )
476+ else :
477+ self .level .uend = self .level .u [0 ].copy ()
478+ if type (self .coll ) == ButcherTableau :
479+ for w , wE , k in zip (self .coll .weights , self .coll_explicit .weights , self .level .f [1 :]):
480+ self .level .uend += self .level .dt * (w * k .impl + wE * k .expl )
481+ elif type (self .coll ) == ButcherTableauEmbedded :
482+ self .u_secondary = self .level .u [0 ].copy ()
483+ for w1 , w2 , w1E , w2E , k in zip (
484+ self .coll .weights [0 ],
485+ self .coll .weights [1 ],
486+ self .coll_explicit .weights [0 ],
487+ self .coll_explicit .weights [1 ],
488+ self .level .f [1 :],
489+ ):
490+ self .level .uend += self .level .dt * (w1 * k .impl + w1E * k .expl )
491+ self .u_secondary += self .level .dt * (w2 * k .impl + w2E * k .expl )
492+
460493
461494class ForwardEuler (RungeKutta ):
462495 """
@@ -480,6 +513,14 @@ class BackwardEuler(RungeKutta):
480513 nodes , weights , matrix = generator .genCoeffs ()
481514
482515
516+ class IMEXEuler (RungeKuttaIMEX ):
517+ nodes = BackwardEuler .nodes
518+ weights = BackwardEuler .weights
519+
520+ matrix = BackwardEuler .matrix
521+ matrix_explicit = ForwardEuler .matrix
522+
523+
483524class CrankNicolson (RungeKutta ):
484525 """
485526 Implicit Runge-Kutta method of second order, A-stable.
@@ -521,8 +562,13 @@ class Heun_Euler(RungeKutta):
521562 Second order explicit embedded Runge-Kutta method.
522563 """
523564
565+ ButcherTableauClass = ButcherTableauEmbedded
566+
524567 generator = RK_SCHEMES ["HEUN" ]()
525- nodes , weights , matrix = generator .genCoeffs ()
568+ nodes , _weights , matrix = generator .genCoeffs ()
569+ weights = np .zeros ((2 , len (_weights )))
570+ weights [0 ] = _weights
571+ weights [1 ] = matrix [- 1 ]
526572
527573 @classmethod
528574 def get_update_order (cls ):
@@ -697,3 +743,41 @@ class ARK548L2SA(RungeKuttaIMEX):
697743 @classmethod
698744 def get_update_order (cls ):
699745 return 5
746+
747+
748+ class ARK324L2SAERK (RungeKutta ):
749+ generator = RK_SCHEMES ["ARK324L2SAERK" ]()
750+ nodes , weights , matrix = generator .genCoeffs (embedded = True )
751+ ButcherTableauClass = ButcherTableauEmbedded
752+
753+ @classmethod
754+ def get_update_order (cls ):
755+ return 3
756+
757+
758+ class ARK324L2SAESDIRK (ARK324L2SAERK ):
759+ generator = RK_SCHEMES ["ARK324L2SAESDIRK" ]()
760+ matrix = generator .Q
761+
762+
763+ class ARK32 (RungeKuttaIMEX ):
764+ ButcherTableauClass = ButcherTableauEmbedded
765+ ButcherTableauClass_explicit = ButcherTableauEmbedded
766+
767+ nodes = ARK324L2SAESDIRK .nodes
768+ weights = ARK324L2SAESDIRK .weights
769+
770+ matrix = ARK324L2SAESDIRK .matrix
771+ matrix_explicit = ARK324L2SAERK .matrix
772+
773+ @classmethod
774+ def get_update_order (cls ):
775+ return 3
776+
777+
778+ class ARK2 (RungeKuttaIMEX ):
779+ generator_IMP = RK_SCHEMES ["ARK222EDIRK" ]()
780+ generator_EXP = RK_SCHEMES ["ARK222ERK" ]()
781+
782+ nodes , weights , matrix = generator_IMP .genCoeffs ()
783+ _ , weights_explicit , matrix_explicit = generator_EXP .genCoeffs ()
0 commit comments