@@ -108,6 +108,16 @@ function initialize!(integrator, cache::RadauIIA7Cache)
108108 integrator. k[2 ] = integrator. fsallast
109109 integrator. f (integrator. fsalfirst, integrator. uprev, integrator. p, integrator. t)
110110 integrator. stats. nf += 1
111+ if integrator. opts. adaptive
112+ @unpack abstol, reltol = integrator. opts
113+ if reltol isa Number
114+ cache. rtol = reltol^ (2 / 3 ) / 10
115+ cache. atol = cache. rtol * (abstol / reltol)
116+ else
117+ @. . broadcast= false cache. rtol= reltol^ (2 / 3 ) / 10
118+ @. . broadcast= false cache. atol= cache. rtol * (abstol / reltol)
119+ end
120+ end
111121 nothing
112122end
113123
459469 atmp2 = calculate_residuals (dw2, uprev, u, atol, rtol, internalnorm, t)
460470 atmp3 = calculate_residuals (dw3, uprev, u, atol, rtol, internalnorm, t)
461471 ndw = internalnorm (atmp1, t) + internalnorm (atmp2, t) + internalnorm (atmp3, t)
462-
463472 # check divergence (not in initial step)
464473 if iter > 1
465474 θ = ndw / ndwprev
789798 mass_matrix = integrator. f. mass_matrix
790799
791800 # precalculations rtol pow is (num stages + 1)/(2*num stages)
792- rtol = @. . broadcast= false reltol# ^(5/8)/10
801+ rtol = @. . broadcast= false reltol^ (5 / 8 )/ 10
793802 atol = @. . broadcast= false rtol* (abstol / reltol)
794803 c1m1 = c1 - 1
795804 c2m1 = c2 - 1
859868 ff2 = f (uprev + z2, p, t + c2 * dt)
860869 ff3 = f (uprev + z3, p, t + c3 * dt)
861870 ff4 = f (uprev + z4, p, t + c4 * dt)
862- ff5 = f (uprev + z4 , p, t + dt) # c5 = 1
871+ ff5 = f (uprev + z5 , p, t + dt) # c5 = 1
863872 integrator. stats. nf += 5
864873
865874 fw1 = @. . broadcast= false TI11* ff1 + TI12* ff2 + TI13* ff3 + TI14* ff4 + TI15* ff5
904913 atmp4 = calculate_residuals (dw4, uprev, u, atol, rtol, internalnorm, t)
905914 atmp5 = calculate_residuals (dw5, uprev, u, atol, rtol, internalnorm, t)
906915 ndw = internalnorm (atmp1, t) + internalnorm (atmp2, t) + internalnorm (atmp3, t) + internalnorm (atmp4, t) + internalnorm (atmp5, t)
907-
908916 # check divergence (not in initial step)
909917 if iter > 1
910918 θ = ndw / ndwprev
939947 break
940948 end
941949 end
950+
942951 if fail_convergence
943952 integrator. force_stepfail = true
944953 integrator. stats. nnonlinconvfail += 1
@@ -948,15 +957,13 @@ end
948957 cache. iter = iter
949958
950959 u = @. . broadcast= false uprev + z5
951- #=
960+
952961 if adaptive
953- e1dt, e2dt, e3dt = e1 / dt, e2 / dt, e3 / dt
954- tmp = @.. broadcast=false e1dt*z1+e2dt*z2+e3dt*z3
962+ e1dt, e2dt, e3dt, e4dt, e5dt = e1 / dt, e2 / dt, e3 / dt, e4 / dt, e5 / dt
963+ tmp = @. . broadcast= false e1dt* z1+ e2dt* z2+ e3dt* z3+ e4dt * z4 + e5dt * z5
955964 mass_matrix != I && (tmp = mass_matrix * tmp)
956965 utilde = @. . broadcast= false integrator. fsalfirst+ tmp
957966 alg. smooth_est && (utilde = LU1 \ utilde; integrator. stats. nsolve += 1 )
958- # RadauIIA5 needs a transformed rtol and atol see
959- # https://github.com/luchr/ODEInterface.jl/blob/0bd134a5a358c4bc13e0fb6a90e27e4ee79e0115/src/radau5.f#L399-L421
960967 atmp = calculate_residuals (utilde, uprev, u, atol, rtol, internalnorm, t)
961968 integrator. EEst = internalnorm (atmp, t)
962969
970977 integrator. EEst = internalnorm (atmp, t)
971978 end
972979 end
973- =#
980+
974981 if integrator. EEst <= oneunit (integrator. EEst)
975982 cache. dtprev = dt
976983 if alg. extrapolant != :constant
@@ -1002,10 +1009,10 @@ end
10021009 @unpack c1, c2, c3, c4, γ, α1, β1, α2, β2, e1, e2, e3, e4, e5 = cache. tab
10031010 @unpack κ, cont1, cont2, cont3, cont4 = cache
10041011 @unpack z1, z2, z3, z4, z5, w1, w2, w3, w4, w5 = cache
1005- @unpack dw1, ubuff, dw23, dw45, cubuff = cache
1006- @unpack k, k2, k3, k4, fw1, fw2, fw3, fw4, fw5 = cache
1012+ @unpack dw1, ubuff, dw23, dw45, cubuff1, cubuff2 = cache
1013+ @unpack k, k2, k3, k4, k5, fw1, fw2, fw3, fw4, fw5 = cache
10071014 @unpack J, W1, W2, W3 = cache
1008- tmp, atmp, jac_config, linsolve1, linsolve2, rtol, atol, step_limiter! = cache
1015+ @unpack tmp, atmp, jac_config, linsolve1, linsolve2, rtol, atol, step_limiter! = cache
10091016 @unpack internalnorm, abstol, reltol, adaptive = integrator. opts
10101017 alg = unwrap_alg (integrator, true )
10111018 @unpack maxiters = alg
@@ -1023,7 +1030,7 @@ end
10231030 c2mc4 = c2 - c4
10241031 c3mc4 = c3 - c4
10251032
1026- γdt, αdt, βdt = γ / dt, α / dt, β / dt
1033+ γdt, α1dt, β1dt, α2dt, β2dt = γ / dt, α1 / dt, β1 / dt, α2 / dt, β2 / dt
10271034 (new_jac = do_newJ (integrator, alg, cache, repeat_step)) &&
10281035 (calc_J! (J, integrator, cache); cache. W_γdt = dt)
10291036 if (new_W = do_newW (integrator, alg, new_jac, cache. W_γdt))
@@ -1130,7 +1137,7 @@ end
11301137 Mw5 = z5
11311138 end
11321139
1133- @. . broadcast= false ubuff= fw1 - γdt * Mw1
1140+ @. . broadcast= false ubuff = fw1 - γdt * Mw1
11341141 needfactor = iter == 1 && new_W
11351142
11361143 linsolve1 = cache. linsolve1
@@ -1145,26 +1152,29 @@ end
11451152
11461153 cache. linsolve1 = linres1. cache
11471154
1148- @. . broadcast= false cubuff= complex (fw2 - αdt * Mw2 + βdt * Mw3,
1149- fw3 - βdt * Mw2 - αdt * Mw3)
1155+ @. . broadcast= false cubuff1 = complex (fw2 - α1dt * Mw2 + β1dt * Mw3, fw3 - β1dt * Mw2 - α1dt * Mw3)
11501156
11511157 linsolve2 = cache. linsolve2
11521158
11531159 if needfactor
1154- linres2 = dolinsolve (integrator, linsolve2; A = W2, b = _vec (cubuff ),
1160+ linres2 = dolinsolve (integrator, linsolve2; A = W2, b = _vec (cubuff1 ),
11551161 linu = _vec (dw23))
11561162 else
1157- linres2 = dolinsolve (integrator, linsolve2; A = nothing , b = _vec (cubuff ),
1163+ linres2 = dolinsolve (integrator, linsolve2; A = nothing , b = _vec (cubuff1 ),
11581164 linu = _vec (dw23))
11591165 end
11601166
11611167 cache. linsolve2 = linres2. cache
11621168
1169+ @. . broadcast= false cubuff2 = complex (fw4 - α2dt * Mw4 + β2dt * Mw5, fw5 - β2dt * Mw4 - α2dt * Mw5)
1170+
1171+ linsolve3 = cache. linsolve3
1172+
11631173 if needfactor
1164- linres3 = dolinsolve (integrator, linsolve2 ; A = W3, b = _vec (cubuff ),
1174+ linres3 = dolinsolve (integrator, linsolve3 ; A = W3, b = _vec (cubuff2 ),
11651175 linu = _vec (dw45))
11661176 else
1167- linres3 = dolinsolve (integrator, linsolve2 ; A = nothing , b = _vec (cubuff ),
1177+ linres3 = dolinsolve (integrator, linsolve3 ; A = nothing , b = _vec (cubuff2 ),
11681178 linu = _vec (dw45))
11691179 end
11701180
@@ -1242,8 +1252,8 @@ end
12421252 #=
12431253 if adaptive
12441254 utilde = w2
1245- e1dt, e2dt, e3dt = e1 / dt, e2 / dt, e3 / dt
1246- @.. broadcast=false tmp=e1dt * z1 + e2dt * z2 + e3dt * z3
1255+ e1dt, e2dt, e3dt, e4dt, e5dt = e1 / dt, e2 / dt, e3 / dt, e4 / dt, e5 / dt
1256+ @.. broadcast=false tmp=e1dt * z1 + e2dt * z2 + e3dt * z3 + e4dt * z4 + e5dt * z5
12471257 mass_matrix != I && (mul!(w1, mass_matrix, tmp); copyto!(tmp, w1))
12481258 @.. broadcast=false ubuff=integrator.fsalfirst + tmp
12491259
@@ -1282,16 +1292,16 @@ end
12821292 if integrator. EEst <= oneunit (integrator. EEst)
12831293 cache. dtprev = dt
12841294 if alg. extrapolant != :constant
1285- @. . broadcast = false cache. cont1 = (z4 - z5) / c4m1
1286- @. . broadcast = false tmp1 = (z3 - z4) / c3mc4
1287- @. . broadcast = false cache. cont2 = (tmp1 - cache. cont1) / c3m1
1288- @. . broadcast = false tmp2 = (z2 - z3) / c2mc3
1289- @. . broadcast = false tmp3 = (tmp2 - tmp) / c2mc4
1290- @. . broadcast = false cache. cont3 = (tmp3 - cache. cont2) / c2m1
1291- @. . broadcast = false tmp4 = (z1 - z2) / c1mc2
1292- @. . broadcast = false tmp5 = (tmp4 - tmp2) / c1mc3
1293- @. . broadcast = false tmp6 = (tmp5 - tmp3) / c1mc4
1294- @. . broadcast = false cache. cont4 = (tmp6 - cache. cont3) / c1m1
1295+ @. . cache. cont1 = (z4 - z5) / c4m1
1296+ @. . tmp1 = (z3 - z4) / c3mc4
1297+ @. . cache. cont2 = (tmp1 - cache. cont1) / c3m1
1298+ @. . tmp2 = (z2 - z3) / c2mc3
1299+ @. . tmp3 = (tmp2 - tmp) / c2mc4
1300+ @. . cache. cont3 = (tmp3 - cache. cont2) / c2m1
1301+ @. . tmp4 = (z1 - z2) / c1mc2
1302+ @. . tmp5 = (tmp4 - tmp2) / c1mc3
1303+ @. . tmp6 = (tmp5 - tmp3) / c1mc4
1304+ @. . cache. cont4 = (tmp6 - cache. cont3) / c1m1
12951305 end
12961306 end
12971307
0 commit comments