Differentiation through MTK DAE initialization with SCCNonlinearProblem is broken for all AD backends #1358

@ChrisRackauckas-Claude

Description

Mooncake fails with a MooncakeRuleCompilationError (caused by a stack overflow at the LLVM/JIT level) when differentiating through solve for an MTK ODEProblem that has DAE initialization (algebraic constraints with guesses).

The stack overflow occurs while Mooncake.build_rrule compiles the rule, not during the actual gradient computation. The deeply nested type structure of MTK's closure (containing ODEProblem, MTKParameters, GeneratedFunctionWrapper, OverrideInitData, System, etc.) causes Mooncake's rule compiler to overflow the stack.

Partial fixes attempted in ModelingToolkit.jl PR SciML/ModelingToolkit.jl#4331:

  • tangent_type(::Type{<:Base.ImmutableDict}) = NoTangent — fixes one recursion path
  • tangent_type(::Type{<:System}) = NoTangent — System is self-referential (has systems::Vector{System}, parent::Union{Nothing, System})
  • tangent_type(::Type{<:GeneratedFunctionWrapper}) = NoTangent
  • tangent_type(::Type{<:ObservedFunctionCache}) = NoTangent
  • @from_rrule bridges for MTKParameters constructor, remake_buffer, SetInitialUnknowns

These fix the Julia-level StackOverflowError in tangent_type, but the LLVM-level stack overflow during rule compilation persists.
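
For readers unfamiliar with the pattern used in those partial fixes, here is a minimal sketch of a tangent_type overload on a self-referential type. ToyTree is a hypothetical stand-in invented for illustration (it is not an MTK type), and the snippet assumes only that Mooncake.tangent_type and Mooncake.NoTangent behave as the overloads listed above suggest.

```julia
using Mooncake

# Hypothetical self-referential type, standing in for a structure such as
# MTK's System (which holds `systems::Vector{System}` and a `parent`).
mutable struct ToyTree
    name::Symbol
    children::Vector{ToyTree}
    parent::Union{Nothing, ToyTree}
end

# Declare that ToyTree carries no derivative information, so its tangent
# type is given directly instead of being derived from its fields.
Mooncake.tangent_type(::Type{ToyTree}) = Mooncake.NoTangent

# Build a small tree with a parent back-reference.
root = ToyTree(:root, ToyTree[], nothing)
push!(root.children, ToyTree(:leaf, ToyTree[], root))

# Query the tangent type; the overload above answers it.
tt = Mooncake.tangent_type(ToyTree)
```

As the issue notes, overloads of this kind only address the Julia-level recursion in tangent_type; they do not by themselves remove the LLVM-level overflow during rule compilation.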

MWE

import Pkg
Pkg.activate(; temp = true)
Pkg.add(["ModelingToolkit", "OrdinaryDiffEq", "Mooncake",
    "SciMLSensitivity", "SciMLStructures", "SymbolicIndexingInterface"])

using ModelingToolkit, OrdinaryDiffEq, Mooncake
using ModelingToolkit: t_nounits as t, D_nounits as D
using SciMLSensitivity
import SciMLStructures as SS
using SymbolicIndexingInterface

# Minimal DAE system: ODE + algebraic constraint requiring initialization
@parameters a b
@variables x(t) y(t)

eqs = [
    D(x) ~ a * x + y,
    0 ~ x^2 - b * y,  # algebraic constraint
]

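# MTK v10+ API: `System` / `@mtkcompile` replace `ODESystem` / `@mtkbuild`
@mtkcompile sys = System(eqs, t)

# Initial conditions and parameters are passed as one operating-point map
prob = ODEProblem(sys, [x => 1.0, a => -0.5, b => 2.0], (0.0, 1.0);
    guesses = [y => 1.0])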
tunables, repack, _ = SS.canonicalize(SS.Tunable(), parameter_values(prob))

# Forward solve works fine
sol = solve(prob, Rodas5P())

# Mooncake gradient fails
loss = let prob = prob, repack = repack
    p -> begin
        new_prob = remake(prob; p = repack(p))
        sol = solve(new_prob, Rodas5P(); abstol = 1e-8, reltol = 1e-6)
        sum(sol)
    end
end

rule = Mooncake.build_rrule(loss, tunables)  # MooncakeRuleCompilationError (stack overflow) thrown here
val, (_, grad) = Mooncake.value_and_gradient!!(rule, loss, tunables)

Error

Warning: detected a stack overflow; program state may be corrupted, so further execution might be unreliable.
MooncakeRuleCompilationError: an error occurred while Mooncake was compiling a rule to differentiate something.

The crash occurs at the LLVM/JIT level during code generation for the complex closure type.
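
Until rule compilation succeeds, a reference gradient for the same loss could be obtained by central finite differences, for comparison once any AD path works. The sketch below is an assumption-laden illustration: fd_gradient and toy_loss are hypothetical names, not part of any package, and the toy loss merely stands in for the MWE's loss.

```julia
# Central finite-difference gradient of a scalar function `f` at `x`.
# `fd_gradient` is a hypothetical helper written for this sketch.
function fd_gradient(f, x::AbstractVector{<:Real}; h = 1e-6)
    x0 = collect(float.(x))      # work on a floating-point copy
    g = similar(x0)
    for i in eachindex(x0)
        xp = copy(x0); xp[i] += h   # step forward in coordinate i
        xm = copy(x0); xm[i] -= h   # step backward in coordinate i
        g[i] = (f(xp) - f(xm)) / (2h)
    end
    return g
end

# Toy loss with a known gradient: (2*p1 + 3*p2, 3*p1).
toy_loss(p) = p[1]^2 + 3 * p[1] * p[2]

g = fd_gradient(toy_loss, [1.0, 2.0])
```

Applied to the MWE, the call would be fd_gradient(loss, tunables; h = ...). Because the loss goes through an adaptive solve with reltol = 1e-6, a step as small as the default would likely be dominated by solver noise, so a larger h or tighter tolerances would probably be needed.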

Environment

  • Julia 1.12
  • ModelingToolkit v11
  • Mooncake v0.4
  • SciMLSensitivity latest
