Skip to content

Conversation

@AstitvaAggarwal
Copy link
Member

@AstitvaAggarwal AstitvaAggarwal commented Jul 9, 2025

Closes SciML/SciMLSensitivity.jl#1230 and chalk-lab/Mooncake.jl#587.

This PR handles cases where:
DiffEqBase._concrete_solve_adjoint must error out when using ReverseDiffAdjoint/TrackerAdjoint while differentiating via Mooncake. This error was already handled by SciMLSensitivity but was not getting hit.
Therefore calling Mooncake.@mooncake_overlay for DiffEqBase.set_mooncakeoriginator_if_mooncake is required otherwise the Mooncake.DerivedRule which contains primal typechecks fails: as Tracker for example adds tags such as Tracker.TrackerReal{Float64} around Float64's to the forward pass primals.

The previous PR handled:
any other case (eg: not using ReverseDiffAdjoint/TrackerAdjoint) when it is required to use a Mooncake.rrule!! for DiffEqBase.set_mooncakeoriginator_if_mooncake in a Mooncake.DerivedRule.

julia> using OrdinaryDiffEq, SciMLSensitivity, Mooncake, Test
       mooncake_gradient(f, x) = Mooncake.value_and_gradient!!(Mooncake.build_rrule(f, x), f, x)[2][2]

       odef(du, u, p, t) = du .= u .* p
       const prob = ODEProblem(odef, [2.0], (0.0, 1.0), [3.0])

       struct senseloss{T}
           sense::T
       end
       function (f::senseloss)(u0p)
           sum(solve(prob, Tsit5(), u0 = u0p[1:1], p = u0p[2:2], abstol = 1e-12,
               reltol = 1e-12, saveat = 0.1, sensealg = f.sense))
       end
       function loss(u0p)
           sum(solve(prob, Tsit5(), u0 = u0p[1:1], p = u0p[2:2], abstol = 1e-12, reltol = 1e-12,
               saveat = 0.1))
       end
       u0p = [2.0, 3.0]
WARNING: redefinition of constant Main.prob. This may fail, cause incorrect answers, or produce other errors.

julia> @test_throws SciMLSensitivity.MooncakeTrackedRealError mooncake_gradient(senseloss(TrackerAdjoint()), u0p) ≈ dup
Test Passed
      Thrown: SciMLSensitivity.MooncakeTrackedRealError

@ChrisRackauckas ChrisRackauckas merged commit 367b691 into SciML:master Jul 9, 2025
38 of 47 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Mooncake gives the wrong aggregator and thus does not give contextualized error messages

2 participants