Skip to content

Commit b52879c

Browse files
committed
chore: add tests for GK adjoint to concrete_solve_derivatives.jl
1 parent 76aceef commit b52879c

File tree

1 file changed

+32
-1
lines changed

1 file changed

+32
-1
lines changed

test/concrete_solve_derivatives.jl

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,14 @@ dp13 = Zygote.gradient(
268268
saveat = 0.1, save_idxs = 1:1,
269269
sensealg = GaussAdjoint())),
270270
u0, p)
271-
271+
du014,
272+
dp14 = Zygote.gradient(
273+
(u0,
274+
p) -> sum(solve(prob, Tsit5(), u0 = u0, p = p,
275+
abstol = 1e-14, reltol = 1e-14,
276+
saveat = 0.1, save_idxs = 1:1,
277+
sensealg = GaussKronrodAdjoint())),
278+
u0, p)
272279
@test ū02du05 rtol=1e-12
273280
@test ū02du06 rtol=1e-12
274281
@test ū02du07 rtol=1e-12
@@ -278,6 +285,7 @@ dp13 = Zygote.gradient(
278285
#@test ū02 ≈ du011 rtol=1e-12
279286
@test ū02du012 rtol=1e-12
280287
@test ū02du013 rtol=1e-12
288+
@test ū02du014 rtol=1e-12
281289
@test adj2dp5 rtol=1e-12
282290
@test adj2dp6 rtol=1e-12
283291
@test adj2dp7 rtol=1e-12
@@ -287,6 +295,7 @@ dp13 = Zygote.gradient(
287295
#@test adj2 ≈ dp11 rtol=1e-12
288296
@test adj2dp12 rtol=1e-12
289297
@test adj2dp13 rtol=1e-12
298+
@test adj2dp14 rtol=1e-12
290299

291300
###
292301
### Only End
@@ -415,6 +424,15 @@ dp9 = Zygote.gradient(
415424
sensealg = GaussAdjoint())),
416425
u0,
417426
p)
427+
du010,
428+
dp10 = Zygote.gradient(
429+
(u0,
430+
p) -> sum(solve(proboop, Tsit5(), u0 = u0, p = p,
431+
abstol = 1e-14, reltol = 1e-14,
432+
saveat = 0.1,
433+
sensealg = GaussKronrodAdjoint())),
434+
u0,
435+
p)
418436

419437
@test ū0du01 rtol=1e-12
420438
@test ū0du02 rtol=1e-12
@@ -425,6 +443,7 @@ dp9 = Zygote.gradient(
425443
@test_broken ū0du07 rtol=1e-12
426444
@test ū0du08 rtol=1e-12
427445
@test ū0du09 rtol=1e-12
446+
@test ū0du010 rtol=1e-12
428447
@test adjdp1' rtol=1e-12
429448
@test adjdp2' rtol=1e-12
430449
@test adjdp3' rtol=1e-12
@@ -434,6 +453,7 @@ dp9 = Zygote.gradient(
434453
@test_broken adjdp7' rtol=1e-12
435454
@test adjdp8' rtol=1e-12
436455
@test adjdp9' rtol=1e-12
456+
@test adjdp10' rtol=1e-12
437457

438458
###
439459
### forward
@@ -571,6 +591,15 @@ dp15 = Zygote.gradient(
571591
sensealg = GaussAdjoint())),
572592
u0,
573593
p)
594+
du016,
595+
dp16 = Zygote.gradient(
596+
(u0,
597+
p) -> sum(solve(proboop, Tsit5(), u0 = u0, p = p,
598+
abstol = 1e-14, reltol = 1e-14,
599+
save_idxs = 1, saveat = 0.1,
600+
sensealg = GaussKronrodAdjoint())),
601+
u0,
602+
p)
574603

575604
@test ū02du05 rtol=1e-12
576605
@test ū02du06 rtol=1e-12
@@ -583,6 +612,7 @@ dp15 = Zygote.gradient(
583612
@test ū02du013 rtol=1e-12
584613
@test ū02du014 rtol=1e-12
585614
@test ū02du015 rtol=1e-12
615+
@test ū02du016 rtol=1e-12
586616
@test adj2dp5 rtol=1e-12
587617
@test adj2dp6 rtol=1e-12
588618
@test adj2dp7 rtol=1e-12
@@ -594,6 +624,7 @@ dp15 = Zygote.gradient(
594624
@test adj2dp13 rtol=1e-12
595625
@test adj2dp14 rtol=1e-12
596626
@test adj2dp15 rtol=1e-12
627+
@test adj2dp16 rtol=1e-12
597628

598629
# Handle VecOfArray Derivatives
599630
dp1 = Zygote.gradient(

0 commit comments

Comments
 (0)