Skip to content

Commit fbd1649

Browse files
Merge pull request #1273 from maxesit/gk-concrete
Add GK Adjoint to `concrete_solve.jl`, and the respective tests
2 parents 5d955d3 + b52879c commit fbd1649

File tree

2 files changed

+36
-4
lines changed

2 files changed

+36
-4
lines changed

src/concrete_solve.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -369,14 +369,15 @@ function DiffEqBase._concrete_solve_adjoint(
369369
sensealg::Union{BacksolveAdjoint,
370370
QuadratureAdjoint,
371371
InterpolatingAdjoint,
372-
GaussAdjoint},
372+
GaussAdjoint,
373+
GaussKronrodAdjoint},
373374
u0, p, originator::SciMLBase.ADOriginator,
374375
args...; save_start = true, save_end = true,
375376
saveat = eltype(prob.tspan)[],
376377
save_idxs = nothing,
377378
initializealg_default = SciMLBase.OverrideInit(; abstol = 1e-6, reltol = 1e-3),
378379
kwargs...)
379-
if !(sensealg isa GaussAdjoint) &&
380+
if !((sensealg isa GaussAdjoint)||(sensealg isa GaussKronrodAdjoint)) &&
380381
!(p isa Union{Nothing, SciMLBase.NullParameters, AbstractArray}) ||
381382
(p isa AbstractArray && !Base.isconcretetype(eltype(p)))
382383
throw(AdjointSensitivityParameterCompatibilityError())
@@ -386,7 +387,7 @@ function DiffEqBase._concrete_solve_adjoint(
386387
tunables, repack = p, identity
387388
elseif isscimlstructure(p)
388389
tunables, repack, aliases = canonicalize(Tunable(), p)
389-
elseif sensealg isa Union{QuadratureAdjoint, GaussAdjoint}
390+
elseif sensealg isa Union{QuadratureAdjoint, GaussAdjoint, GaussKronrodAdjoint}
390391
tunables, repack = Functors.functor(p)
391392
else
392393
throw(SciMLStructuresCompatibilityError())

test/concrete_solve_derivatives.jl

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,14 @@ dp13 = Zygote.gradient(
268268
saveat = 0.1, save_idxs = 1:1,
269269
sensealg = GaussAdjoint())),
270270
u0, p)
271-
271+
du014,
272+
dp14 = Zygote.gradient(
273+
(u0,
274+
p) -> sum(solve(prob, Tsit5(), u0 = u0, p = p,
275+
abstol = 1e-14, reltol = 1e-14,
276+
saveat = 0.1, save_idxs = 1:1,
277+
sensealg = GaussKronrodAdjoint())),
278+
u0, p)
272279
@test ū02du05 rtol=1e-12
273280
@test ū02du06 rtol=1e-12
274281
@test ū02du07 rtol=1e-12
@@ -278,6 +285,7 @@ dp13 = Zygote.gradient(
278285
#@test ū02 ≈ du011 rtol=1e-12
279286
@test ū02du012 rtol=1e-12
280287
@test ū02du013 rtol=1e-12
288+
@test ū02du014 rtol=1e-12
281289
@test adj2dp5 rtol=1e-12
282290
@test adj2dp6 rtol=1e-12
283291
@test adj2dp7 rtol=1e-12
@@ -287,6 +295,7 @@ dp13 = Zygote.gradient(
287295
#@test adj2 ≈ dp11 rtol=1e-12
288296
@test adj2dp12 rtol=1e-12
289297
@test adj2dp13 rtol=1e-12
298+
@test adj2dp14 rtol=1e-12
290299

291300
###
292301
### Only End
@@ -415,6 +424,15 @@ dp9 = Zygote.gradient(
415424
sensealg = GaussAdjoint())),
416425
u0,
417426
p)
427+
du010,
428+
dp10 = Zygote.gradient(
429+
(u0,
430+
p) -> sum(solve(proboop, Tsit5(), u0 = u0, p = p,
431+
abstol = 1e-14, reltol = 1e-14,
432+
saveat = 0.1,
433+
sensealg = GaussKronrodAdjoint())),
434+
u0,
435+
p)
418436

419437
@test ū0du01 rtol=1e-12
420438
@test ū0du02 rtol=1e-12
@@ -425,6 +443,7 @@ dp9 = Zygote.gradient(
425443
@test_broken ū0du07 rtol=1e-12
426444
@test ū0du08 rtol=1e-12
427445
@test ū0du09 rtol=1e-12
446+
@test ū0du010 rtol=1e-12
428447
@test adjdp1' rtol=1e-12
429448
@test adjdp2' rtol=1e-12
430449
@test adjdp3' rtol=1e-12
@@ -434,6 +453,7 @@ dp9 = Zygote.gradient(
434453
@test_broken adjdp7' rtol=1e-12
435454
@test adjdp8' rtol=1e-12
436455
@test adjdp9' rtol=1e-12
456+
@test adjdp10' rtol=1e-12
437457

438458
###
439459
### forward
@@ -571,6 +591,15 @@ dp15 = Zygote.gradient(
571591
sensealg = GaussAdjoint())),
572592
u0,
573593
p)
594+
du016,
595+
dp16 = Zygote.gradient(
596+
(u0,
597+
p) -> sum(solve(proboop, Tsit5(), u0 = u0, p = p,
598+
abstol = 1e-14, reltol = 1e-14,
599+
save_idxs = 1, saveat = 0.1,
600+
sensealg = GaussKronrodAdjoint())),
601+
u0,
602+
p)
574603

575604
@test ū02du05 rtol=1e-12
576605
@test ū02du06 rtol=1e-12
@@ -583,6 +612,7 @@ dp15 = Zygote.gradient(
583612
@test ū02du013 rtol=1e-12
584613
@test ū02du014 rtol=1e-12
585614
@test ū02du015 rtol=1e-12
615+
@test ū02du016 rtol=1e-12
586616
@test adj2dp5 rtol=1e-12
587617
@test adj2dp6 rtol=1e-12
588618
@test adj2dp7 rtol=1e-12
@@ -594,6 +624,7 @@ dp15 = Zygote.gradient(
594624
@test adj2dp13 rtol=1e-12
595625
@test adj2dp14 rtol=1e-12
596626
@test adj2dp15 rtol=1e-12
627+
@test adj2dp16 rtol=1e-12
597628

598629
# Handle VecOfArray Derivatives
599630
dp1 = Zygote.gradient(

0 commit comments

Comments
 (0)