Skip to content

AD through initialization does not work #4342

@hersle

Description

@hersle

Following up on #4275 (comment) with a slightly more complex example that needs initialization, because the initial condition of `a` depends on the parameter `Ωr0`:

# Reproducer setup: an MTK ODE whose initial condition for `a` is written in terms
# of a parameter, so building/solving the problem goes through initialization.
using ModelingToolkit, SymbolicIndexingInterface, OrdinaryDiffEq, Zygote, SciMLSensitivity, ForwardDiff, FiniteDiff

# Symbolic independent variable and its derivative operator.
@independent_variables t
D = Differential(t)
# One unknown `a(t)` and three parameters.
@variables a(t)
@parameters Ωr0 Ωm0 ΩΛ0
# Single ODE: D(a) = (Ωr0/a^4 + Ωm0/a^3 + ΩΛ0) * a^2.
eqs = [D(a) ~ (Ωr0/a^4 + Ωm0/a^3 + ΩΛ0) * a^2]
# Initial condition a = Ωr0 * t depends on the parameter Ωr0; this is what makes
# the example need initialization.
ics = [a => (Ωr0) * t]
@mtkbuild M = System(eqs, t; initial_conditions = ics)

# All three parameters are NaN placeholders here; `a_final` replaces them through
# `setter` before every solve.
# NOTE(review): the span starts at 1e-3 rather than 0, presumably so that
# a = Ωr0 * t is nonzero (the RHS divides by powers of a) — confirm.
prob = ODEProblem(M, [M.Ωr0 => NaN, M.Ωm0 => NaN, M.ΩΛ0 => NaN], (1e-3, 2.0))

# Out-of-place setter for the three parameters: `setter(prob, vals)` returns a new
# `(u0, p)` pair (as used in `a_final`) instead of mutating `prob`.
setter = setsym_oop(prob, [M.Ωr0, M.Ωm0, M.ΩΛ0])
"""
    a_final(x)

Solve the global `prob` with `Ωr0 = x[1]`, `Ωm0 = x[2]` and `ΩΛ0 = 1 - Ωr0 - Ωm0`,
then return the first component of the final saved state.

The new parameter values are written through the global out-of-place `setter`. The
solve uses `Tsit5()` and selects `SciMLSensitivity.GaussAdjoint()` as the sensitivity
algorithm used when the solve is differentiated in reverse mode.
"""
function a_final(x)
    Ωr = x[1]
    Ωm = x[2]
    # The third parameter is not free: it is chosen so the three sum to one.
    ΩΛ = 1 - Ωr - Ωm
    newu0, newp = setter(prob, [Ωr, Ωm, ΩΛ])
    problem = remake(prob; u0 = newu0, p = newp)
    # Only the endpoint is needed, so intermediate steps are not saved.
    solution = solve(problem, Tsit5();
                     save_everystep = false,
                     sensealg = SciMLSensitivity.GaussAdjoint())
    final_state = solution.u[end]
    return final_state[1]
end

# Evaluation point x0 = [Ωr0, Ωm0]; ΩΛ0 is derived inside `a_final`.
x0 = [1e-5, 0.3]
a_final(x0) # works (plain forward solve)
FiniteDiff.finite_difference_gradient(a_final, x0) # works (finite differences)
ForwardDiff.gradient(a_final, x0) # works (forward-mode AD)
# fails (reverse mode via GaussAdjoint): UndefVarError `j` raised in Zygote's pullback
# of NaNMath.sqrt, reached from SciMLBase.get_initial_values — see stacktrace below.
Zygote.gradient(a_final, x0) # fails
ERROR: UndefVarError: `j` not defined
Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0 [inlined]
  [2] (::Zygote.Pullback{Tuple{Core.IntrinsicFunction, Float64}, Tuple{Core.IntrinsicFunction}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:100
  [3] sqrt
    @ ~/.julia/packages/NaNMath/zoR8O/src/NaNMath.jl:60 [inlined]
  [4] (::Zygote.Pullback{Tuple{typeof(NaNMath.sqrt), Float64}, Any})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
  [5] generated_callfunc
    @ ~/.julia/packages/SymbolicUtils/bL56Y/src/code.jl:713 [inlined]
  [6] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{…}, ChainRules.var"#721#724"{…}})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
  [7] #305
    @ ~/.julia/packages/Zygote/55SqB/src/lib/lib.jl:214 [inlined]
  [8] #2189#back
    @ ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:72 [inlined]
  [9] RuntimeGeneratedFunction
    @ ~/.julia/packages/RuntimeGeneratedFunctions/9kbBw/src/RuntimeGeneratedFunctions.jl:187 [inlined]
 [10] #305
    @ ~/.julia/packages/Zygote/55SqB/src/lib/lib.jl:214 [inlined]
 [11] #2189#back
    @ ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:72 [inlined]
 [12] _generated_call
    @ ~/.julia/packages/ModelingToolkitBase/w4DDk/src/systems/codegen_utils.jl:0 [inlined]
 [13] #305
    @ ~/.julia/packages/Zygote/55SqB/src/lib/lib.jl:214 [inlined]
 [14] #2189#back
    @ ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:72 [inlined]
 [15] GeneratedFunctionWrapper
    @ ~/.julia/packages/ModelingToolkitBase/w4DDk/src/systems/codegen_utils.jl:417 [inlined]
 [16] TimeIndependentObservedFunction
    @ ~/.julia/packages/SymbolicIndexingInterface/D9aQf/src/state_indexing.jl:142 [inlined]
 [17] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{…}, ChainRules.var"#721#724"{…}})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
 [18] AbstractStateGetIndexer
    @ ~/.julia/packages/SymbolicIndexingInterface/D9aQf/src/value_provider_interface.jl:166 [inlined]
 [19] #305
    @ ~/.julia/packages/Zygote/55SqB/src/lib/lib.jl:214 [inlined]
 [20] (::Zygote.var"#2189#back#307"{Zygote.var"#305#306"{…}})(Δ::ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{…}, ChainRules.var"#721#724"{…}})
    @ Zygote ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:72
 [21] call_composed
    @ ./operators.jl:1045 [inlined]
 [22] (::Zygote.Pullback{Tuple{…}, Any})(Δ::ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{…}, ChainRules.var"#721#724"{…}})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
 [23] call_composed (repeats 2 times)
    @ ./operators.jl:1044 [inlined]
 [24] #_#103
    @ ./operators.jl:1041 [inlined]
 [25] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{…}, ChainRules.var"#721#724"{…}})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
 [26] #305
    @ ~/.julia/packages/Zygote/55SqB/src/lib/lib.jl:214 [inlined]
 [27] #2189#back
    @ ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:72 [inlined]
 [28] ComposedFunction
    @ ./operators.jl:1041 [inlined]
 [29] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{…}, ChainRules.var"#721#724"{…}})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
 [30] #get_initial_values#9
    @ ~/.julia/packages/SciMLBase/YgvDS/src/initialization.jl:322 [inlined]
 [31] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Tuple{ChainRulesCore.InplaceableThunk{…}, @NamedTuple{…}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
 [32] get_initial_values
    @ ~/.julia/packages/SciMLBase/YgvDS/src/initialization.jl:267 [inlined]
 [33] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Tuple{ChainRulesCore.InplaceableThunk{…}, @NamedTuple{…}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
 [34] #292
    @ ~/.julia/packages/SciMLSensitivity/3s0Zy/src/concrete_solve.jl:587 [inlined]
 [35] (::Zygote.Pullback{Tuple{SciMLSensitivity.var"#292#302"{GaussAdjoint{…}, Nothing, @NamedTuple{…}}, Vector{Float64}}, Any})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
 [36] (::Zygote.var"#88#89"{Zygote.Pullback{Tuple{SciMLSensitivity.var"#292#302"{GaussAdjoint{…}, Nothing, @NamedTuple{…}}, Vector{Float64}}, Any}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface.jl:97
 [37] _concrete_solve_adjoint(::ODEProblem{…}, ::Tsit5{…}, ::GaussAdjoint{…}, ::Vector{…}, ::MTKParameters{…}, ::SciMLBase.ChainRulesOriginator; save_start::Bool, save_end::Bool, saveat::Vector{…}, save_idxs::Nothing, initializealg_default::SciMLBase.OverrideInit{…}, kwargs::@Kwargs{…})
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/3s0Zy/src/concrete_solve.jl:602
 [38] _concrete_solve_adjoint
    @ ~/.julia/packages/SciMLSensitivity/3s0Zy/src/concrete_solve.jl:465 [inlined]
 [39] #_solve_adjoint#51
    @ ~/.julia/packages/DiffEqBase/2gXT1/src/solve.jl:1045 [inlined]
 [40] _solve_adjoint
    @ ~/.julia/packages/DiffEqBase/2gXT1/src/solve.jl:1022 [inlined]
 [41] #rrule#4
    @ ~/.julia/packages/DiffEqBase/2gXT1/ext/DiffEqBaseChainRulesCoreExt.jl:32 [inlined]
 [42] rrule
    @ ~/.julia/packages/DiffEqBase/2gXT1/ext/DiffEqBaseChainRulesCoreExt.jl:26 [inlined]
 [43] rrule
    @ ~/.julia/packages/ChainRulesCore/Vsbj9/src/rules.jl:144 [inlined]
 [44] chain_rrule_kw
    @ ~/.julia/packages/Zygote/55SqB/src/compiler/chainrules.jl:246 [inlined]
 [45] macro expansion
    @ ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0 [inlined]
 [46] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(DiffEqBase.solve_up), ::ODEProblem{…}, ::GaussAdjoint{…}, ::Vector{…}, ::MTKParameters{…}, ::Tsit5{…})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:81
 [47] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [48] adjoint
    @ ~/.julia/packages/Zygote/55SqB/src/lib/lib.jl:211 [inlined]
 [49] _pullback
    @ ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:67 [inlined]
 [50] #solve#37
    @ ~/.julia/packages/DiffEqBase/2gXT1/src/solve.jl:587 [inlined]
 [51] _pullback(::Zygote.Context{…}, ::DiffEqBase.var"##solve#37", ::GaussAdjoint{…}, ::Nothing, ::Nothing, ::Val{…}, ::@Kwargs{…}, ::typeof(solve), ::ODEProblem{…}, ::Tsit5{…})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
 [52] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [53] adjoint
    @ ~/.julia/packages/Zygote/55SqB/src/lib/lib.jl:211 [inlined]
 [54] _pullback
    @ ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:67 [inlined]
 [55] solve
    @ ~/.julia/packages/DiffEqBase/2gXT1/src/solve.jl:575 [inlined]
 [56] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(solve), ::ODEProblem{…}, ::Tsit5{…})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
 [57] a_final
    @ ./REPL[8]:7 [inlined]
 [58] _pullback(ctx::Zygote.Context{false}, f::typeof(a_final), args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
 [59] pullback(f::Function, cx::Zygote.Context{false}, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface.jl:96
 [60] pullback
    @ ~/.julia/packages/Zygote/55SqB/src/compiler/interface.jl:94 [inlined]
 [61] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface.jl:153

Metadata

Assignees

No one assigned

    Labels

    bug — Something isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions