@@ -37,17 +37,24 @@ def __init__(self, weights, nodes, matrix):
3737 elif len (nodes ) != matrix .shape [0 ]:
3838 raise ParameterError (f'Incompatible number of nodes! Need { matrix .shape [0 ]} , got { len (nodes )} ' )
3939
40- # Set number of nodes, left and right interval boundaries
41- self .num_solution_stages = 1
42- self .num_nodes = matrix .shape [0 ] + self .num_solution_stages
40+ self .globally_stiffly_accurate = np .allclose (matrix [- 1 ], weights )
41+
4342 self .tleft = 0.0
4443 self .tright = 1.0
45-
46- self .nodes = np . append ( np . append ( [0 ], nodes ), [ 1 ])
44+ self . num_solution_stages = 0 if self . globally_stiffly_accurate else 1
45+ self .num_nodes = matrix . shape [0 ] + self . num_solution_stages
4746 self .weights = weights
48- self .Qmat = np .zeros ([self .num_nodes + 1 , self .num_nodes + 1 ])
49- self .Qmat [1 :- 1 , 1 :- 1 ] = matrix
50- self .Qmat [- 1 , 1 :- 1 ] = weights # this is for computing the solution to the step from the previous stages
47+
48+ if self .globally_stiffly_accurate :
49+ # For globally stiffly accurate methods, the last row of the Butcher tableau is the same as the weights.
50+ self .nodes = np .append ([0 ], nodes )
51+ self .Qmat = np .zeros ([self .num_nodes + 1 , self .num_nodes + 1 ])
52+ self .Qmat [1 :, 1 :] = matrix
53+ else :
54+ self .nodes = np .append (np .append ([0 ], nodes ), [1 ])
55+ self .Qmat = np .zeros ([self .num_nodes + 1 , self .num_nodes + 1 ])
56+ self .Qmat [1 :- 1 , 1 :- 1 ] = matrix
57+ self .Qmat [- 1 , 1 :- 1 ] = weights # this is for computing the solution to the step from the previous stages
5158
5259 self .left_is_node = True
5360 self .right_is_node = self .nodes [- 1 ] == self .tright
@@ -277,7 +284,7 @@ def update_nodes(self):
277284 rhs += lvl .dt * self .QI [m + 1 , j ] * self .get_full_f (lvl .f [j ])
278285
279286 # implicit solve with prefactor stemming from the diagonal of Qd, use previous stage as initial guess
280- if self .coll . implicit :
287+ if self .QI [ m + 1 , m + 1 ] != 0 :
281288 lvl .u [m + 1 ][:] = prob .solve_system (
282289 rhs , lvl .dt * self .QI [m + 1 , m + 1 ], lvl .u [m ], lvl .time + lvl .dt * self .coll .nodes [m + 1 ]
283290 )
0 commit comments