Skip to content

Commit 5acab16

Browse files
edits
1 parent e0c8857 commit 5acab16

File tree

5 files changed

+309
-227
lines changed

5 files changed

+309
-227
lines changed

lib/OrdinaryDiffEqFIRK/src/firk_addsteps.jl

Lines changed: 107 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ return
381381
end
382382

383383
function _ode_addsteps!(integrator, cache::RadauIIA5Cache, repeat_step = false)
384-
@unpack t, dt, uprev, u, f, p, fsallast, fsalfirst = integrator
384+
@unpack t, dt, uprev, u, f, p, fsallast, fsalfirst, k = integrator
385385
@unpack T11, T12, T13, T21, T22, T23, T31, TI11, TI12, TI13, TI21, TI22, TI23, TI31, TI32, TI33 = cache.tab
386386
@unpack c1, c2, γ, α, β, e1, e2, e3 = cache.tab
387387
@unpack κ, cont1, cont2, cont3 = cache
@@ -396,8 +396,8 @@ alg = unwrap_alg(integrator, true)
396396
mass_matrix = integrator.f.mass_matrix
397397

398398
# precalculations
399-
c1m1 = c1 - 1
400-
c2m1 = c2 - 1
399+
c1m1 = (c1 - 1)*dt
400+
c2m1 = (c2 - 1)*dt
401401
c1mc2 = c1 - c2
402402
γdt, αdt, βdt = γ / dt, α / dt, β / dt
403403
if (new_W = do_newW(integrator, alg, new_jac, cache.W_γdt))
@@ -418,16 +418,16 @@ if integrator.iter == 1 || integrator.u_modified || alg.extrapolant == :constant
418418
@.. broadcast=false w1=uzero
419419
@.. broadcast=false w2=uzero
420420
@.. broadcast=false w3=uzero
421-
@.. broadcast=false cache.cont1=uzero
422-
@.. broadcast=false cache.cont2=uzero
423-
@.. broadcast=false cache.cont3=uzero
421+
@.. broadcast=false integrator.k[3] = uzero
422+
@.. broadcast=false integrator.k[4] = uzero
423+
@.. broadcast=false integrator.k[5] = uzero
424424
else
425-
c3′ = dt / cache.dtprev
425+
c3′ = dt
426426
c1′ = c1 * c3′
427427
c2′ = c2 * c3′
428-
@.. broadcast=false z1=c1′ * (cont1 + (c1′ - c2m1) * (cont2 + (c1′ - c1m1) * cont3))
429-
@.. broadcast=false z2=c2′ * (cont1 + (c2′ - c2m1) * (cont2 + (c2′ - c1m1) * cont3))
430-
@.. broadcast=false z3=c3′ * (cont1 + (c3′ - c2m1) * (cont2 + (c3′ - c1m1) * cont3))
428+
@.. broadcast=false z1=c1′ * (k[3] + (c1′ - c2m1) * (k[4] + (c1′ - c1m1) * k[5]))
429+
@.. broadcast=false z2=c2′ * (k[3] + (c2′ - c2m1) * (k[4] + (c2′ - c1m1) * k[5]))
430+
@.. broadcast=false z3=c3′ * (k[3] + (c3′ - c2m1) * (k[4] + (c3′ - c1m1) * k[5]))
431431
@.. broadcast=false w1=TI11 * z1 + TI12 * z2 + TI13 * z3
432432
@.. broadcast=false w2=TI21 * z1 + TI22 * z2 + TI23 * z3
433433
@.. broadcast=false w3=TI31 * z1 + TI32 * z2 + TI33 * z3
@@ -565,10 +565,10 @@ step_limiter!(u, integrator, p, t + dt)
565565
if integrator.EEst <= oneunit(integrator.EEst)
566566
cache.dtprev = dt
567567
if alg.extrapolant != :constant
568-
@.. broadcast=false cache.cont1=(z2 - z3) / c2m1
569-
@.. broadcast=false tmp=(z1 - z2) / c1mc2
570-
@.. broadcast=false cache.cont2=(tmp - cache.cont1) / c1m1
571-
@.. broadcast=false cache.cont3=cache.cont2 - (tmp - z1 / c1) / c2
568+
integrator.k[3] = (z2 - z3) / (dt * c2m1)
569+
@.. tmp=(z1 - z2) / (dt * c1mc2)
570+
integrator.k[4] = (tmp - integrator.k[3]) / c1m1
571+
integrator.k[5] = integrator.k[4] - (tmp - z1 / c1) / c2
572572
end
573573
end
574574

@@ -579,11 +579,11 @@ end
579579

580580
function _ode_addsteps!(integrator, cache::RadauIIA9ConstantCache,
581581
repeat_step = false)
582-
@unpack t, dt, uprev, u, f, p = integrator
582+
@unpack t, dt, uprev, u, f, p, k = integrator
583583
@unpack T11, T12, T13, T14, T15, T21, T22, T23, T24, T25, T31, T32, T33, T34, T35, T41, T42, T43, T44, T45, T51 = cache.tab #= T52 = 1, T53 = 0, T54 = 1, T55 = 0=#
584584
@unpack TI11, TI12, TI13, TI14, TI15, TI21, TI22, TI23, TI24, TI25, TI31, TI32, TI33, TI34, TI35, TI41, TI42, TI43, TI44, TI45, TI51, TI52, TI53, TI54, TI55 = cache.tab
585585
@unpack c1, c2, c3, c4, γ, α1, β1, α2, β2, e1, e2, e3, e4, e5 = cache.tab
586-
@unpack κ, cont1, cont2, cont3, cont4, cont5 = cache
586+
@unpack κ= cache
587587
@unpack internalnorm, abstol, reltol, adaptive = integrator.opts
588588
alg = unwrap_alg(integrator, true)
589589
@unpack maxiters = alg
@@ -618,41 +618,46 @@ integrator.stats.nw += 1
618618
# TODO better initial guess
619619
if integrator.iter == 1 || integrator.u_modified || alg.extrapolant == :constant
620620
cache.dtprev = one(cache.dtprev)
621-
z1 = w1 = map(zero, u)
622-
z2 = w2 = map(zero, u)
623-
z3 = w3 = map(zero, u)
624-
z4 = w4 = map(zero, u)
625-
z5 = w5 = map(zero, u)
626-
cache.cont1 = map(zero, u)
627-
cache.cont2 = map(zero, u)
628-
cache.cont3 = map(zero, u)
629-
cache.cont4 = map(zero, u)
630-
cache.cont5 = map(zero, u)
621+
z1 = map(zero, u)
622+
z2 = map(zero, u)
623+
z3 = map(zero, u)
624+
z4 = map(zero, u)
625+
z5 = map(zero, u)
626+
w1 = map(zero, u)
627+
w2 = map(zero, u)
628+
w3 = map(zero, u)
629+
w4 = map(zero, u)
630+
w5 = map(zero, u)
631+
integrator.k[3] = map(zero, u)
632+
integrator.k[4] = map(zero, u)
633+
integrator.k[5] = map(zero, u)
634+
integrator.k[6] = map(zero, u)
635+
integrator.k[7] = map(zero, u)
631636
else
632637
c5′ = dt / cache.dtprev
633638
c1′ = c1 * c5′
634639
c2′ = c2 * c5′
635640
c3′ = c3 * c5′
636641
c4′ = c4 * c5′
637-
z1 = @.. c1′ * (cont1 +
638-
(c1′-c4m1) * (cont2 +
639-
(c1′ - c3m1) * (cont3 +
640-
(c1′ - c2m1) * (cont4 + (c1′ - c1m1) * cont5))))
641-
z2 = @.. c2′ * (cont1 +
642-
(c2′-c4m1) * (cont2 +
643-
(c2′ - c3m1) * (cont3 +
644-
(c2′ - c2m1) * (cont4 + (c2′ - c1m1) * cont5))))
645-
z3 = @.. c3′ * (cont1 +
646-
(c3′-c4m1) * (cont2 +
647-
(c3′ - c3m1) * (cont3 +
648-
(c3′ - c2m1) * (cont4 + (c3′ - c1m1) * cont5))))
649-
z4 = @.. c4′ * (cont1 +
650-
(c4′-c4m1) * (cont2 +
651-
(c4′ - c3m1) * (cont3 +
652-
(c4′ - c2m1) * (cont4 + (c4′ - c1m1) * cont5))))
653-
z5 = @.. c5′ * (cont1 +
654-
(c5′-c4m1) * (cont2 +
655-
(c5′ - c3m1) * (cont3 + (c5′ - c2m1) * (cont4 + (c5′ - c1m1) * cont5))))
642+
z1 = @.. c1′ * (k[3] +
643+
(c1′-c4m1) * (k[4] +
644+
(c1′ - c3m1) * (k[5] +
645+
(c1′ - c2m1) * (k[6] + (c1′ - c1m1) * k[7]))))
646+
z2 = @.. c2′ * (k[3] +
647+
(c2′-c4m1) * (k[4] +
648+
(c2′ - c3m1) * (k[5] +
649+
(c2′ - c2m1) * (k[6] + (c2′ - c1m1) * k[7]))))
650+
z3 = @.. c3′ * (k[3] +
651+
(c3′-c4m1) * (k[4] +
652+
(c3′ - c3m1) * (k[5] +
653+
(c3′ - c2m1) * (k[6] + (c3′ - c1m1) * k[7]))))
654+
z4 = @.. c4′ * (k[3] +
655+
(c4′-c4m1) * (k[4] +
656+
(c4′ - c3m1) * (k[5] +
657+
(c4′ - c2m1) * (k[6] + (c4′ - c1m1) * k[7]))))
658+
z5 = @.. c5′ * (k[3] +
659+
(c5′-c4m1) * (k[4] +
660+
(c5′ - c3m1) * (k[5] + (c5′ - c2m1) * (k[6] + (c5′ - c1m1) * k[7]))))
656661
w1 = @.. broadcast=false TI11*z1+TI12*z2+TI13*z3+TI14*z4+TI15*z5
657662
w2 = @.. broadcast=false TI21*z1+TI22*z2+TI23*z3+TI24*z4+TI25*z5
658663
w3 = @.. broadcast=false TI31*z1+TI32*z2+TI33*z3+TI34*z4+TI35*z5
@@ -762,30 +767,30 @@ if fail_convergence
762767
integrator.stats.nnonlinconvfail += 1
763768
return
764769
end
765-
cache.ηold = η
766-
cache.iter = iter
770+
#cache.ηold = η
771+
#cache.iter = iter
767772

768773
u = @.. broadcast=false uprev+z5
769774

770775

771776
if integrator.EEst <= oneunit(integrator.EEst)
772-
cache.dtprev = dt
777+
#cache.dtprev = dt
773778
if alg.extrapolant != :constant
774-
cache.cont1 = @.. (z4 - z5) / c4m1 # first derivative on [c4, 1]
779+
integrator.k[3] = (z4 - z5) / c4m1 # first derivative on [c4, 1]
775780
tmp1 = @.. (z3 - z4) / c3mc4 # first derivative on [c3, c4]
776-
cache.cont2 = @.. (tmp1 - cache.cont1) / c3m1 # second derivative on [c3, 1]
781+
integrator.k[4] = (tmp1 - integrator.k[3]) / c3m1 # second derivative on [c3, 1]
777782
tmp2 = @.. (z2 - z3) / c2mc3 # first derivative on [c2, c3]
778783
tmp3 = @.. (tmp2 - tmp1) / c2mc4 # second derivative on [c2, c4]
779-
cache.cont3 = @.. (tmp3 - cache.cont2) / c2m1 # third derivative on [c2, 1]
784+
integrator.k[5] = (tmp3 - integrator.k[4]) / c2m1 # third derivative on [c2, 1]
780785
tmp4 = @.. (z1 - z2) / c1mc2 # first derivative on [c1, c2]
781786
tmp5 = @.. (tmp4 - tmp2) / c1mc3 # second derivative on [c1, c3]
782787
tmp6 = @.. (tmp5 - tmp3) / c1mc4 # third derivative on [c1, c4]
783-
cache.cont4 = @.. (tmp6 - cache.cont3) / c1m1 #fourth derivative on [c1, 1]
788+
integrator.k[6] = (tmp6 - integrator.k[5]) / c1m1 #fourth derivative on [c1, 1]
784789
tmp7 = @.. z1 / c1 #first derivative on [0, c1]
785790
tmp8 = @.. (tmp4 - tmp7) / c2 #second derivative on [0, c2]
786791
tmp9 = @.. (tmp5 - tmp8) / c3 #third derivative on [0, c3]
787792
tmp10 = @.. (tmp6 - tmp9) / c4 #fourth derivative on [0,c4]
788-
cache.cont5 = @.. cache.cont4 - tmp10 #fifth derivative on [0,1]
793+
integrator.k[7] = integrator.k[6] - tmp10 #fifth derivative on [0,1]
789794
end
790795
end
791796

@@ -798,11 +803,11 @@ return
798803
end
799804

800805
function _ode_addsteps!(integrator, cache::RadauIIA9Cache, repeat_step = false)
801-
@unpack t, dt, uprev, u, f, p, fsallast, fsalfirst = integrator
806+
@unpack t, dt, uprev, u, f, p, fsallast, fslafirst, k = integrator
802807
@unpack T11, T12, T13, T14, T15, T21, T22, T23, T24, T25, T31, T32, T33, T34, T35, T41, T42, T43, T44, T45, T51 = cache.tab #= T52 = 1, T53 = 0, T54 = 1, T55 = 0=#
803808
@unpack TI11, TI12, TI13, TI14, TI15, TI21, TI22, TI23, TI24, TI25, TI31, TI32, TI33, TI34, TI35, TI41, TI42, TI43, TI44, TI45, TI51, TI52, TI53, TI54, TI55 = cache.tab
804809
@unpack c1, c2, c3, c4, γ, α1, β1, α2, β2, e1, e2, e3, e4, e5 = cache.tab
805-
@unpack κ, cont1, cont2, cont3, cont4, cont5 = cache
810+
@unpack κ= cache
806811
@unpack z1, z2, z3, z4, z5, w1, w2, w3, w4, w5 = cache
807812
@unpack dw1, ubuff, dw23, dw45, cubuff1, cubuff2 = cache
808813
@unpack k, k2, k3, k4, k5, fw1, fw2, fw3, fw4, fw5 = cache
@@ -849,36 +854,36 @@ if integrator.iter == 1 || integrator.u_modified || alg.extrapolant == :constant
849854
@.. broadcast=false w3=uzero
850855
@.. broadcast=false w4=uzero
851856
@.. broadcast=false w5=uzero
852-
@.. broadcast=false cache.cont1=uzero
853-
@.. broadcast=false cache.cont2=uzero
854-
@.. broadcast=false cache.cont3=uzero
855-
@.. broadcast=false cache.cont4=uzero
856-
@.. broadcast=false cache.cont5=uzero
857+
@.. broadcast=false integrator.k[3]=uzero
858+
@.. broadcast=false integrator.k[4]=uzero
859+
@.. broadcast=false integrator.k[5]=uzero
860+
@.. broadcast=false integrator.k[6]=uzero
861+
@.. broadcast=false integrator.k[7]=uzero
857862
else
858863
c5′ = dt / cache.dtprev
859864
c1′ = c1 * c5′
860865
c2′ = c2 * c5′
861866
c3′ = c3 * c5′
862867
c4′ = c4 * c5′
863-
@.. z1 = c1′ * (cont1 +
864-
(c1′-c4m1) * (cont2 +
865-
(c1′ - c3m1) * (cont3 +
866-
(c1′ - c2m1) * (cont4 + (c1′ - c1m1) * cont5))))
867-
@.. z2 = c2′ * (cont1 +
868-
(c2′-c4m1) * (cont2 +
869-
(c2′ - c3m1) * (cont3 +
870-
(c2′ - c2m1) * (cont4 + (c2′ - c1m1) * cont5))))
871-
@.. z3 = c3′ * (cont1 +
872-
(c3′-c4m1) * (cont2 +
873-
(c3′ - c3m1) * (cont3 +
874-
(c3′ - c2m1) * (cont4 + (c3′ - c1m1) * cont5))))
875-
@.. z4 = c4′ * (cont1 +
876-
(c4′-c4m1) * (cont2 +
877-
(c4′ - c3m1) * (cont3 +
878-
(c4′ - c2m1) * (cont4 + (c4′ - c1m1) * cont5))))
879-
@.. z5 = c5′ * (cont1 +
880-
(c5′-c4m1) * (cont2 +
881-
(c5′ - c3m1) * (cont3 + (c5′ - c2m1) * (cont4 + (c5′ - c1m1) * cont5))))
868+
@.. z1 = c1′ * (k[3] +
869+
(c1′-c4m1) * (k[4] +
870+
(c1′ - c3m1) * (k[5] +
871+
(c1′ - c2m1) * (k[6] + (c1′ - c1m1) * k[7]))))
872+
@.. z2 = c2′ * (k[3] +
873+
(c2′-c4m1) * (k[4] +
874+
(c2′ - c3m1) * (k[5] +
875+
(c2′ - c2m1) * (k[6] + (c2′ - c1m1) * k[7]))))
876+
@.. z3 = c3′ * (k[3] +
877+
(c3′-c4m1) * (k[4] +
878+
(c3′ - c3m1) * (k[5] +
879+
(c3′ - c2m1) * (k[6] + (c3′ - c1m1) * k[7]))))
880+
@.. z4 = c4′ * (k[3] +
881+
(c4′-c4m1) * (k[4] +
882+
(c4′ - c3m1) * (k[5] +
883+
(c4′ - c2m1) * (k[6] + (c4′ - c1m1) * k[7]))))
884+
@.. z5 = c5′ * (k[3] +
885+
(c5′-c4m1) * (k[4] +
886+
(c5′ - c3m1) * (k[5] + (c5′ - c2m1) * (k[6] + (c5′ - c1m1) * k[7]))))
882887
@.. w1 = TI11*z1+TI12*z2+TI13*z3+TI14*z4+TI15*z5
883888
@.. w2 = TI21*z1+TI22*z2+TI23*z3+TI24*z4+TI25*z5
884889
@.. w3 = TI31*z1+TI32*z2+TI33*z3+TI34*z4+TI35*z5
@@ -1060,25 +1065,22 @@ cache.iter = iter
10601065

10611066
step_limiter!(u, integrator, p, t + dt)
10621067

1063-
if integrator.EEst <= oneunit(integrator.EEst)
1064-
cache.dtprev = dt
1065-
if alg.extrapolant != :constant
1066-
@.. cache.cont1 = (z4 - z5) / c4m1 # first derivative on [c4, 1]
1067-
@.. tmp = (z3 - z4) / c3mc4 # first derivative on [c3, c4]
1068-
@.. cache.cont2 = (tmp - cache.cont1) / c3m1 # second derivative on [c3, 1]
1069-
@.. tmp2 = (z2 - z3) / c2mc3 # first derivative on [c2, c3]
1070-
@.. tmp3 = (tmp2 - tmp) / c2mc4 # second derivative on [c2, c4]
1071-
@.. cache.cont3 = (tmp3 - cache.cont2) / c2m1 # third derivative on [c2, 1]
1072-
@.. tmp4 = (z1 - z2) / c1mc2 # first derivative on [c1, c2]
1073-
@.. tmp5 = (tmp4 - tmp2) / c1mc3 # second derivative on [c1, c3]
1074-
@.. tmp6 = (tmp5 - tmp3) / c1mc4 # third derivative on [c1, c4]
1075-
@.. cache.cont4 = (tmp6 - cache.cont3) / c1m1 #fourth derivative on [c1, 1]
1076-
@.. tmp7 = z1 / c1 #first derivative on [0, c1]
1077-
@.. tmp8 = (tmp4 - tmp7) / c2 #second derivative on [0, c2]
1078-
@.. tmp9 = (tmp5 - tmp8) / c3 #third derivative on [0, c3]
1079-
@.. tmp10 = (tmp6 - tmp9) / c4 #fourth derivative on [0,c4]
1080-
@.. cache.cont5 = cache.cont4 - tmp10 #fifth derivative on [0,1]
1081-
end
1068+
if alg.extrapolant != :constant
1069+
integrator.k[3] = (z4 - z5) / c4m1 # first derivative on [c4, 1]
1070+
@.. tmp = (z3 - z4) / c3mc4 # first derivative on [c3, c4]
1071+
integrator.k[4] = (tmp - integrator.k[3]) / c3m1 # second derivative on [c3, 1]
1072+
@.. tmp2 = (z2 - z3) / c2mc3 # first derivative on [c2, c3]
1073+
@.. tmp3 = (tmp2 - tmp) / c2mc4 # second derivative on [c2, c4]
1074+
integrator.k[5] = (tmp3 - integrator.k[4]) / c2m1 # third derivative on [c2, 1]
1075+
@.. tmp4 = (z1 - z2) / c1mc2 # first derivative on [c1, c2]
1076+
@.. tmp5 = (tmp4 - tmp2) / c1mc3 # second derivative on [c1, c3]
1077+
@.. tmp6 = (tmp5 - tmp3) / c1mc4 # third derivative on [c1, c4]
1078+
integrator.k[6] = (tmp6 - integrator.k[5]) / c1m1 #fourth derivative on [c1, 1]
1079+
@.. tmp7 = z1 / c1 #first derivative on [0, c1]
1080+
@.. tmp8 = (tmp4 - tmp7) / c2 #second derivative on [0, c2]
1081+
@.. tmp9 = (tmp5 - tmp8) / c3 #third derivative on [0, c3]
1082+
@.. tmp10 = (tmp6 - tmp9) / c4 #fourth derivative on [0,c4]
1083+
integrator.k[7] = integrator.k[6] - tmp10 #fifth derivative on [0,1]
10821084
end
10831085

10841086
f(fsallast, u, p, t + dt)
@@ -1297,11 +1299,11 @@ function _ode_addstep!(integrator, cache::AdaptiveRadauConstantCache, repeat_ste
12971299
end
12981300

12991301
function _ode_addsteps!(integrator, cache::AdaptiveRadauCache, repeat_step = false)
1300-
@unpack t, dt, uprev, u, f, p, fsallast, fsalfirst = integrator
1302+
@unpack t, dt, uprev, u, f, p, fsallast, fsalfirst, k = integrator
13011303
@unpack num_stages, tabs, index = cache
13021304
tab = tabs[index]
13031305
@unpack T, TI, γ, α, β, c, e = tab
1304-
@unpack κ, cont, derivatives, z, w, c_prime, αdt, βdt= cache
1306+
@unpack κ, derivatives, z, w, c_prime, αdt, βdt= cache
13051307
@unpack dw1, ubuff, dw2, cubuff, dw = cache
13061308
@unpack ks, k, fw, J, W1, W2 = cache
13071309
@unpack tmp, atmp, jac_config, linsolve1, linsolve2, rtol, atol, step_limiter! = cache
@@ -1358,19 +1360,19 @@ function _ode_addsteps!(integrator, cache::AdaptiveRadauCache, repeat_step = fal
13581360
for i in 1 : num_stages
13591361
@.. z[i] = map(zero, u)
13601362
@.. w[i] = map(zero, u)
1361-
@.. cache.cont[i] = map(zero, u)
1363+
integrator.k[i + 2] = map(zero, u)
13621364
end
13631365
else
1364-
c_prime[num_stages] = dt / cache.dtprev
1366+
c_prime[num_stages] = dt
13651367
for i in 1 : num_stages - 1
13661368
c_prime[i] = c[i] * c_prime[num_stages]
13671369
end
13681370
for i in 1 : num_stages # collocation polynomial
1369-
@.. z[i] = cont[num_stages] * (c_prime[i] - c[1] + 1) + cont[num_stages - 1]
1371+
@.. z[i] = k[num_stages + 2] * (c_prime[i] - c[1] + 1) + k[num_stages + 1]
13701372
j = num_stages - 2
13711373
while j > 0
13721374
@.. z[i] *= (c_prime[i] - c[num_stages - j] + 1)
1373-
@.. z[i] += cont[j]
1375+
@.. z[i] += k[j + 2]
13741376
j = j - 1
13751377
end
13761378
@.. z[i] *= c_prime[i]
@@ -1541,7 +1543,7 @@ function _ode_addsteps!(integrator, cache::AdaptiveRadauCache, repeat_step = fal
15411543
end
15421544
end
15431545
for i in 1 : num_stages
1544-
@.. cache.cont[i] = derivatives[i, num_stages]
1546+
integrator.k[i + 2] = derivatives[i, num_stages]
15451547
end
15461548
end
15471549
end

0 commit comments

Comments
 (0)