Skip to content

Commit 4579890

Browse files
Merge pull request #1234 from SciML/ChrisRackauckas-patch-3
Update Mooncake error throw derivatives
2 parents 50cf61d + 90c743a commit 4579890

File tree

1 file changed

+7
-13
lines changed

1 file changed

+7
-13
lines changed

test/alternative_ad_frontend.jl

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,8 @@ dup = Zygote.gradient(senseloss(InterpolatingAdjoint()), u0p)[1]
6161
@test_broken only(Enzyme.gradient(Reverse, senseloss(ForwardSensitivity()), u0p)) dup # broken because ForwardSensitivity not compatible with perturbing u0
6262

6363
@test mooncake_gradient(senseloss(InterpolatingAdjoint()), u0p) dup
64-
@test_throws Any mooncake_gradient(senseloss(ReverseDiffAdjoint()), u0p) dup
65-
@test_throws Any mooncake_gradient(senseloss(TrackerAdjoint()), u0p) dup
66-
#@test_broken @test_throws SciMLSensitivity.MooncakeTrackedRealError mooncake_gradient(senseloss(ReverseDiffAdjoint()), u0p) ≈ dup
67-
#@test_broken @test_throws SciMLSensitivity.MooncakeTrackedRealError mooncake_gradient(senseloss(TrackerAdjoint()), u0p) ≈ dup
64+
@test_throws SciMLSensitivity.MooncakeTrackedRealError mooncake_gradient(senseloss(ReverseDiffAdjoint()), u0p)
65+
@test_throws SciMLSensitivity.MooncakeTrackedRealError mooncake_gradient(senseloss(TrackerAdjoint()), u0p)
6866
@test mooncake_gradient(senseloss(ForwardDiffSensitivity()), u0p) dup
6967
@test_broken mooncake_gradient(senseloss(ForwardSensitivity()), u0p) dup # broken because ForwardSensitivity not compatible with perturbing u0
7068

@@ -103,10 +101,8 @@ dup = Zygote.gradient(senseloss2(InterpolatingAdjoint()), u0p)[1]
103101
@test_broken only(Enzyme.gradient(Reverse, senseloss2(ForwardSensitivity()), u0p)) dup # broken because ForwardSensitivity not compatible with perturbing u0
104102

105103
@test mooncake_gradient(senseloss2(InterpolatingAdjoint()), u0p) dup
106-
@test_throws Any mooncake_gradient(senseloss2(ReverseDiffAdjoint()), u0p) dup
107-
@test_throws Any mooncake_gradient(senseloss2(TrackerAdjoint()), u0p) dup
108-
#@test_broken @test_throws SciMLSensitivity.MooncakeTrackedRealError mooncake_gradient(senseloss2(ReverseDiffAdjoint()), u0p) ≈ dup
109-
#@test_broken @test_throws SciMLSensitivity.MooncakeTrackedRealError mooncake_gradient(senseloss2(TrackerAdjoint()), u0p) ≈ dup
104+
@test_throws SciMLSensitivity.MooncakeTrackedRealError mooncake_gradient(senseloss2(ReverseDiffAdjoint()), u0p)
105+
@test_throws SciMLSensitivity.MooncakeTrackedRealError mooncake_gradient(senseloss2(TrackerAdjoint()), u0p)
110106
@test mooncake_gradient(senseloss2(ForwardDiffSensitivity()), u0p) dup
111107
@test_broken mooncake_gradient(senseloss2(ForwardSensitivity()), u0p) dup # broken because ForwardSensitivity not compatible with perturbing u0
112108

@@ -143,10 +139,8 @@ dup = Zygote.gradient(senseloss3(InterpolatingAdjoint()), u0p)[1]
143139
@test_broken only(Enzyme.gradient(Reverse, senseloss3(ForwardSensitivity()), u0p)) dup
144140

145141
@test mooncake_gradient(senseloss3(InterpolatingAdjoint()), u0p) dup
146-
@test_throws Any mooncake_gradient(senseloss3(ReverseDiffAdjoint()), u0p) dup
147-
@test_throws Any mooncake_gradient(senseloss3(TrackerAdjoint()), u0p) dup
148-
#@test_broken @test_throws SciMLSensitivity.MooncakeTrackedRealError mooncake_gradient(senseloss3(ReverseDiffAdjoint()), u0p) ≈ dup
149-
#@test_broken @test_throws SciMLSensitivity.MooncakeTrackedRealError mooncake_gradient(senseloss3(TrackerAdjoint()), u0p) ≈ dup
142+
@test_throws SciMLSensitivity.MooncakeTrackedRealError mooncake_gradient(senseloss3(ReverseDiffAdjoint()), u0p)
143+
@test_throws SciMLSensitivity.MooncakeTrackedRealError mooncake_gradient(senseloss3(TrackerAdjoint()), u0p)
150144
@test mooncake_gradient(senseloss3(ForwardDiffSensitivity()), u0p) dup
151145
@test_broken mooncake_gradient(senseloss3(ForwardSensitivity()), u0p) dup
152146

@@ -289,4 +283,4 @@ grad_rd = ReverseDiff.gradient(loss2, p)
289283
@test grad_fdgrad_fi atol=1e-2
290284
@test grad_fdgrad_zg atol=1e-4
291285
@test grad_fdgrad_rd atol=1e-4
292-
@test_broken mooncake_gradient(loss2, p) grad_rd atol=1e-4
286+
@test_broken mooncake_gradient(loss2, p) grad_rd atol=1e-4

0 commit comments

Comments
 (0)