@@ -287,6 +287,7 @@ mutable struct RadauIIA9ConstantCache{F, Tab, Tol, Dt, U, JType} <:
287287 cont2:: U
288288 cont3:: U
289289 cont4:: U
290+ cont5:: U
290291 dtprev:: Dt
291292 W_γdt:: Dt
292293 status:: NLStatus
@@ -304,7 +305,7 @@ function alg_cache(alg::RadauIIA9, u, rate_prototype, ::Type{uEltypeNoUnits},
304305 κ = alg. κ != = nothing ? convert (uToltype, alg. κ) : convert (uToltype, 1 // 100 )
305306 J = false .* _vec (rate_prototype) .* _vec (rate_prototype)'
306307
307- RadauIIA9ConstantCache (uf, tab, κ, one (uToltype), 10000 , u, u, u, u, dt, dt,
308+ RadauIIA9ConstantCache (uf, tab, κ, one (uToltype), 10000 , u, u, u, u, u, dt, dt,
308309 Convergence, J)
309310end
310311
@@ -333,6 +334,7 @@ mutable struct RadauIIA9Cache{uType, cuType, uNoUnitsType, rateType, JType, W1Ty
333334 cont2:: uType
334335 cont3:: uType
335336 cont4:: uType
337+ cont5:: uType
336338 du1:: rateType
337339 fsalfirst:: rateType
338340 k:: rateType
@@ -407,6 +409,7 @@ function alg_cache(alg::RadauIIA9, u, rate_prototype, ::Type{uEltypeNoUnits},
407409 cont2 = zero (u)
408410 cont3 = zero (u)
409411 cont4 = zero (u)
412+ cont5 = zero (u)
410413
411414 fsalfirst = zero (rate_prototype)
412415 k = zero (rate_prototype)
@@ -462,11 +465,193 @@ function alg_cache(alg::RadauIIA9, u, rate_prototype, ::Type{uEltypeNoUnits},
462465
463466 RadauIIA9Cache (u, uprev,
464467 z1, z2, z3, z4, z5, w1, w2, w3, w4, w5,
465- dw1, ubuff, dw23, dw45, cubuff1, cubuff2, cont1, cont2, cont3, cont4,
468+ dw1, ubuff, dw23, dw45, cubuff1, cubuff2, cont1, cont2, cont3, cont4, cont5,
466469 du1, fsalfirst, k, k2, k3, k4, k5, fw1, fw2, fw3, fw4, fw5,
467470 J, W1, W2, W3,
468471 uf, tab, κ, one (uToltype), 10000 ,
469472 tmp, tmp2, tmp3, tmp4, tmp5, tmp6, atmp, jac_config,
470473 linsolve1, linsolve2, linsolve3, rtol, atol, dt, dt,
471474 Convergence, alg. step_limiter!)
472475end
476+
477+ mutable struct AdaptiveRadauConstantCache{F, Tab, Tol, Dt, U, JType} < :
478+ OrdinaryDiffEqConstantCache
479+ uf:: F
480+ tab:: Tab
481+ κ:: Tol
482+ ηold:: Tol
483+ iter:: Int
484+ cont:: Vector{U}
485+ dtprev:: Dt
486+ W_γdt:: Dt
487+ status:: NLStatus
488+ J:: JType
489+ end
490+
491+ function alg_cache (alg:: AdaptiveRadau , u, rate_prototype, :: Type{uEltypeNoUnits} ,
492+ :: Type{uBottomEltypeNoUnits} ,
493+ :: Type{tTypeNoUnits} , uprev, uprev2, f, t, dt, reltol, p, calck,
494+ :: Val{false} ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
495+ uf = UDerivativeWrapper (f, t, p)
496+ uToltype = constvalue (uBottomEltypeNoUnits)
497+ num_stages = alg. num_stages
498+
499+ if (num_stages == 3 )
500+ tab = BigRadauIIA5Tableau (uToltype, constvalue (tTypeNoUnits))
501+ elseif (num_stages == 5 )
502+ tab = BigRadauIIA9Tableau (uToltype, constvalue (tTypeNoUnits))
503+ elseif (num_stages == 7 )
504+ tab = BigRadauIIA13Tableau (uToltype, constvalue (tTypeNoUnits))
505+ elseif iseven (num_stages) || num_stages < 3
506+ error (" num_stages must be odd and 3 or greater" )
507+ else
508+ tab = adaptiveRadauTableau (uToltype, constvalue (tTypeNoUnits), num_stages)
509+ end
510+
511+ cont = Vector {typeof(u)} (undef, num_stages)
512+ for i in 1 : num_stages
513+ cont[i] = zero (u)
514+ end
515+
516+ κ = alg. κ != = nothing ? convert (uToltype, alg. κ) : convert (uToltype, 1 // 100 )
517+ J = false .* _vec (rate_prototype) .* _vec (rate_prototype)'
518+
519+ AdaptiveRadauConstantCache (uf, tab, κ, one (uToltype), 10000 , cont, dt, dt,
520+ Convergence, J)
521+ end
522+
523+ mutable struct AdaptiveRadauCache{uType, cuType, tType, uNoUnitsType, rateType, JType, W1Type, W2Type,
524+ UF, JC, F1, F2, Tab, Tol, Dt, rTol, aTol, StepLimiter} < :
525+ FIRKMutableCache
526+ u:: uType
527+ uprev:: uType
528+ z:: Vector{uType}
529+ w:: Vector{uType}
530+ c_prime:: Vector{tType}
531+ dw1:: uType
532+ ubuff:: uType
533+ dw2:: Vector{cuType}
534+ cubuff:: Vector{cuType}
535+ dw:: Vector{uType}
536+ cont:: Vector{uType}
537+ derivatives:: Matrix{uType}
538+ du1:: rateType
539+ fsalfirst:: rateType
540+ ks:: Vector{rateType}
541+ k:: rateType
542+ fw:: Vector{rateType}
543+ J:: JType
544+ W1:: W1Type # real
545+ W2:: Vector{W2Type} # complex
546+ uf:: UF
547+ tab:: Tab
548+ κ:: Tol
549+ ηold:: Tol
550+ iter:: Int
551+ tmp:: uType
552+ atmp:: uNoUnitsType
553+ jac_config:: JC
554+ linsolve1:: F1 # real
555+ linsolve2:: Vector{F2} # complex
556+ rtol:: rTol
557+ atol:: aTol
558+ dtprev:: Dt
559+ W_γdt:: Dt
560+ status:: NLStatus
561+ step_limiter!:: StepLimiter
562+ end
563+
564+ function alg_cache (alg:: AdaptiveRadau , u, rate_prototype, :: Type{uEltypeNoUnits} ,
565+ :: Type{uBottomEltypeNoUnits} ,
566+ :: Type{tTypeNoUnits} , uprev, uprev2, f, t, dt, reltol, p, calck,
567+ :: Val{true} ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
568+ uf = UJacobianWrapper (f, t, p)
569+ uToltype = constvalue (uBottomEltypeNoUnits)
570+ num_stages = alg. num_stages
571+
572+ if (num_stages == 3 )
573+ tab = BigRadauIIA5Tableau (uToltype, constvalue (tTypeNoUnits))
574+ elseif (num_stages == 5 )
575+ tab = BigRadauIIA9Tableau (uToltype, constvalue (tTypeNoUnits))
576+ elseif (num_stages == 7 )
577+ tab = BigRadauIIA13Tableau (uToltype, constvalue (tTypeNoUnits))
578+ elseif iseven (num_stages) || num_stages < 3
579+ error (" num_stages must be odd and 3 or greater" )
580+ else
581+ tab = adaptiveRadauTableau (uToltype, constvalue (tTypeNoUnits), num_stages)
582+ end
583+
584+ κ = alg. κ != = nothing ? convert (uToltype, alg. κ) : convert (uToltype, 1 // 100 )
585+
586+ z = Vector {typeof(u)} (undef, num_stages)
587+ w = Vector {typeof(u)} (undef, num_stages)
588+ for i in 1 : num_stages
589+ z[i] = w[i] = zero (u)
590+ end
591+
592+ c_prime = Vector {typeof(t)} (undef, num_stages) # time stepping
593+
594+ dw1 = zero (u)
595+ ubuff = zero (u)
596+ dw2 = [similar (u, Complex{eltype (u)}) for _ in 1 : (num_stages - 1 ) ÷ 2 ]
597+ recursivefill! .(dw2, false )
598+ cubuff = [similar (u, Complex{eltype (u)}) for _ in 1 : (num_stages - 1 ) ÷ 2 ]
599+ recursivefill! .(cubuff, false )
600+ dw = Vector {typeof(u)} (undef, num_stages - 1 )
601+
602+ cont = Vector {typeof(u)} (undef, num_stages)
603+ for i in 1 : num_stages
604+ cont[i] = zero (u)
605+ end
606+
607+ derivatives = Matrix {typeof(u)} (undef, num_stages, num_stages)
608+ for i in 1 : num_stages, j in 1 : num_stages
609+ derivatives[i, j] = zero (u)
610+ end
611+
612+ fsalfirst = zero (rate_prototype)
613+ fw = Vector {typeof(rate_prototype)} (undef, num_stages)
614+ ks = Vector {typeof(rate_prototype)} (undef, num_stages)
615+ for i in 1 : num_stages
616+ ks[i] = fw[i] = zero (rate_prototype)
617+ end
618+ k = ks[1 ]
619+
620+ J, W1 = build_J_W (alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val (true ))
621+ if J isa AbstractSciMLOperator
622+ error (" Non-concrete Jacobian not yet supported by AdaptiveRadau." )
623+ end
624+
625+ W2 = [similar (J, Complex{eltype (W1)}) for _ in 1 : (num_stages - 1 ) ÷ 2 ]
626+ recursivefill! .(W2, false )
627+
628+ du1 = zero (rate_prototype)
629+
630+ tmp = zero (u)
631+
632+ atmp = similar (u, uEltypeNoUnits)
633+ recursivefill! (atmp, false )
634+
635+ jac_config = build_jac_config (alg, f, uf, du1, uprev, u, zero (u), dw1)
636+
637+ linprob = LinearProblem (W1, _vec (ubuff); u0 = _vec (dw1))
638+ linsolve1 = init (linprob, alg. linsolve, alias_A = true , alias_b = true ,
639+ assumptions = LinearSolve. OperatorAssumptions (true ))
640+
641+ linsolve2 = [
642+ init (LinearProblem (W2[i], _vec (cubuff[i]); u0 = _vec (dw2[i])), alg. linsolve, alias_A = true , alias_b = true ,
643+ assumptions = LinearSolve. OperatorAssumptions (true )) for i in 1 : (num_stages - 1 ) ÷ 2 ]
644+
645+ rtol = reltol isa Number ? reltol : zero (reltol)
646+ atol = reltol isa Number ? reltol : zero (reltol)
647+
648+ AdaptiveRadauCache (u, uprev,
649+ z, w, c_prime, dw1, ubuff, dw2, cubuff, dw, cont, derivatives,
650+ du1, fsalfirst, ks, k, fw,
651+ J, W1, W2,
652+ uf, tab, κ, one (uToltype), 10000 , tmp,
653+ atmp, jac_config,
654+ linsolve1, linsolve2, rtol, atol, dt, dt,
655+ Convergence, alg. step_limiter!)
656+ end
657+
0 commit comments