Description
Mooncake crashes with a MooncakeRuleCompilationError (caused by stack overflow at the LLVM/JIT level) when attempting to differentiate through solve for an MTK ODEProblem that has DAE initialization (algebraic constraints with guesses).
The stack overflow occurs during Mooncake.build_rrule compilation, not during the actual gradient computation. The deeply nested type structure of MTK's closure (containing ODEProblem, MTKParameters, GeneratedFunctionWrapper, OverrideInitData, System, etc.) causes Mooncake's rule compiler to overflow the stack.
Partial fixes attempted in ModelingToolkit.jl PR SciML/ModelingToolkit.jl#4331:
- tangent_type(::Type{<:Base.ImmutableDict}) = NoTangent: fixes one recursion path
- tangent_type(::Type{<:System}) = NoTangent: System is self-referential (has systems::Vector{System}, parent::Union{Nothing, System})
- tangent_type(::Type{<:GeneratedFunctionWrapper}) = NoTangent
- tangent_type(::Type{<:ObservedFunctionCache}) = NoTangent
- @from_rrule bridges for the MTKParameters constructor, remake_buffer, and SetInitialUnknowns
These fix the Julia-level StackOverflowError in tangent_type, but the LLVM-level stack overflow during rule compilation persists.
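For readers unfamiliar with this kind of override, the pattern in the list above can be sketched on a toy type. This is only an illustration, not the code from the PR: Node is a hypothetical self-referential struct standing in for System, and the sketch assumes that adding a tangent_type method for a user-defined type is acceptable in the installed Mooncake version.

```julia
using Mooncake

# Hypothetical self-referential type standing in for MTK's System:
# one of its fields is a vector of the struct's own type.
struct Node
    name::Symbol
    children::Vector{Node}
end

# Declare that Node carries no derivative information, so Mooncake
# does not recurse into its fields when deriving a tangent type.
Mooncake.tangent_type(::Type{Node}) = Mooncake.NoTangent
```

With this method defined, Mooncake.tangent_type(Node) returns NoTangent directly. As noted above, overrides of this kind address only the Julia-level recursion in tangent_type, not the overflow during rule compilation.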
MWE
import Pkg
Pkg.activate(; temp = true)
Pkg.add(["ModelingToolkit", "OrdinaryDiffEq", "Mooncake",
    "SciMLSensitivity", "SciMLStructures", "SymbolicIndexingInterface"])
using ModelingToolkit, OrdinaryDiffEq, Mooncake
using ModelingToolkit: t_nounits as t, D_nounits as D
using SciMLSensitivity
import SciMLStructures as SS
using SymbolicIndexingInterface
# Minimal DAE system: ODE + algebraic constraint requiring initialization
@parameters a b
@variables x(t) y(t)
eqs = [
    D(x) ~ a * x + y,
    0 ~ x^2 - b * y,  # algebraic constraint
]
@mtkbuild sys = ODESystem(eqs, t)
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [a => -0.5, b => 2.0],
    guesses = [y => 1.0])
tunables, repack, _ = SS.canonicalize(SS.Tunable(), parameter_values(prob))
# Forward solve works fine
sol = solve(prob, Rodas5P())
# Mooncake gradient fails
loss = let prob = prob, repack = repack
    p -> begin
        new_prob = remake(prob; p = repack(p))
        sol = solve(new_prob, Rodas5P(); abstol = 1e-8, reltol = 1e-6)
        sum(sol)
    end
end
rule = Mooncake.build_rrule(loss, tunables) # StackOverflow here
val, (_, grad) = Mooncake.value_and_gradient!!(rule, loss, tunables)
Error
Warning: detected a stack overflow; program state may be corrupted, so further execution might be unreliable.
MooncakeRuleCompilationError: an error occurred while Mooncake was compiling a rule to differentiate something.
The crash occurs at the LLVM/JIT level during code generation for the complex closure type.
Environment
- Julia 1.12
- ModelingToolkit v11
- Mooncake v0.4
- SciMLSensitivity latest
Related
- Add Mooncake and EnzymeCore extensions for AD compatibility ModelingToolkit.jl#4331 (partial tangent_type fixes)
- Update Core8 MTK tests for mutation-aware AD #1356 (Core8 test updates with @test_broken)