-
-
Notifications
You must be signed in to change notification settings - Fork 245
Open
Labels
bug: Something isn't working
Description
Following up #4275 (comment) with a slightly more complex example that needs initialization:
using ModelingToolkit, SymbolicIndexingInterface, OrdinaryDiffEq, Zygote, SciMLSensitivity, ForwardDiff, FiniteDiff
# Independent variable `t` and the time-derivative operator used in the ODE below.
@independent_variables t
D = Differential(t)
# One unknown, a(t), and three scalar parameters.
@variables a(t)
@parameters Ωr0 Ωm0 ΩΛ0
# Single ODE: da/dt = a^2 * √(Ωr0/a^4 + Ωm0/a^3 + ΩΛ0).
eqs = [D(a) ~ √(Ωr0/a^4 + Ωm0/a^3 + ΩΛ0) * a^2]
# The initial value of `a` is given as an expression in the parameter Ωr0 (and t).
# So a(t0) must be computed by the initialization system rather than supplied
# numerically. This is what makes the example "need initialization".
ics = [a => √(Ωr0) * t]
# Build the system with `ics` as its initial conditions; the result is bound to `M`.
@mtkbuild M = System(eqs, t; initial_conditions = ics)
# All three parameters are NaN placeholders here; real values are written in later
# through `setter`. tspan starts at 1e-3, presumably to keep a(t0) = √(Ωr0) * t0
# away from 0, where Ωr0/a^4 would blow up.
prob = ODEProblem(M, [M.Ωr0 => NaN, M.Ωm0 => NaN, M.ΩΛ0 => NaN], (1e-3, 2.0))
# Out-of-place setter for (Ωr0, Ωm0, ΩΛ0): called as `setter(prob, vals)` and
# returning a new `(u0, p)` pair (see its use in `a_final`).
setter = setsym_oop(prob, [M.Ωr0, M.Ωm0, M.ΩΛ0])
"""
    a_final(x)

Solve the module-level problem `prob` with parameters taken from `x` and return
the first state component of the last saved solution point.

`x[1]` is used as `Ωr0` and `x[2]` as `Ωm0`. `ΩΛ0` is not free: it is chosen so
that `Ωr0 + Ωm0 + ΩΛ0 == 1`.

The solve requests `GaussAdjoint` as the sensitivity algorithm. This determines how
reverse-mode AD (e.g. `Zygote.gradient`) differentiates through `solve`.
"""
function a_final(x)
    Ωr0 = x[1]
    Ωm0 = x[2]
    ΩΛ0 = 1 - Ωr0 - Ωm0

    # `setter` (built with `setsym_oop`) yields fresh u0/p values; `prob` itself is
    # only read here.
    u0_new, p_new = setter(prob, [Ωr0, Ωm0, ΩΛ0])
    newprob = remake(prob; u0 = u0_new, p = p_new)

    # Only the endpoint is needed, so intermediate steps are not stored.
    alg = Tsit5()
    sol = solve(newprob, alg;
                save_everystep = false,
                sensealg = SciMLSensitivity.GaussAdjoint())

    # With a single unknown this is presumably a(t_end) — confirm against `M`.
    a_end = sol.u[end][1]
    return a_end
end
# Evaluation point: Ωr0 = 1e-5, Ωm0 = 0.3 (a_final then uses ΩΛ0 = 1 - 1e-5 - 0.3).
x0 = [1e-5, 0.3]
a_final(x0) # works
# Finite-difference and forward-mode gradients both succeed on the same function.
FiniteDiff.finite_difference_gradient(a_final, x0) # works
ForwardDiff.gradient(a_final, x0) # works
# Reverse mode (Zygote with the GaussAdjoint sensealg) fails. Per the stack trace,
# it fails while differentiating SciMLBase `get_initial_values`, i.e. the
# initialization step.
Zygote.gradient(a_final, x0) # fails with: ERROR: UndefVarError: `j` not defined
Stacktrace:
[1] macro expansion
@ ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0 [inlined]
[2] (::Zygote.Pullback{Tuple{Core.IntrinsicFunction, Float64}, Tuple{Core.IntrinsicFunction}})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:100
[3] sqrt
@ ~/.julia/packages/NaNMath/zoR8O/src/NaNMath.jl:60 [inlined]
[4] (::Zygote.Pullback{Tuple{typeof(NaNMath.sqrt), Float64}, Any})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
[5] generated_callfunc
@ ~/.julia/packages/SymbolicUtils/bL56Y/src/code.jl:713 [inlined]
[6] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{…}, ChainRules.var"#721#724"{…}})
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
[7] #305
@ ~/.julia/packages/Zygote/55SqB/src/lib/lib.jl:214 [inlined]
[8] #2189#back
@ ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:72 [inlined]
[9] RuntimeGeneratedFunction
@ ~/.julia/packages/RuntimeGeneratedFunctions/9kbBw/src/RuntimeGeneratedFunctions.jl:187 [inlined]
[10] #305
@ ~/.julia/packages/Zygote/55SqB/src/lib/lib.jl:214 [inlined]
[11] #2189#back
@ ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:72 [inlined]
[12] _generated_call
@ ~/.julia/packages/ModelingToolkitBase/w4DDk/src/systems/codegen_utils.jl:0 [inlined]
[13] #305
@ ~/.julia/packages/Zygote/55SqB/src/lib/lib.jl:214 [inlined]
[14] #2189#back
@ ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:72 [inlined]
[15] GeneratedFunctionWrapper
@ ~/.julia/packages/ModelingToolkitBase/w4DDk/src/systems/codegen_utils.jl:417 [inlined]
[16] TimeIndependentObservedFunction
@ ~/.julia/packages/SymbolicIndexingInterface/D9aQf/src/state_indexing.jl:142 [inlined]
[17] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{…}, ChainRules.var"#721#724"{…}})
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
[18] AbstractStateGetIndexer
@ ~/.julia/packages/SymbolicIndexingInterface/D9aQf/src/value_provider_interface.jl:166 [inlined]
[19] #305
@ ~/.julia/packages/Zygote/55SqB/src/lib/lib.jl:214 [inlined]
[20] (::Zygote.var"#2189#back#307"{Zygote.var"#305#306"{…}})(Δ::ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{…}, ChainRules.var"#721#724"{…}})
@ Zygote ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:72
[21] call_composed
@ ./operators.jl:1045 [inlined]
[22] (::Zygote.Pullback{Tuple{…}, Any})(Δ::ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{…}, ChainRules.var"#721#724"{…}})
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
[23] call_composed (repeats 2 times)
@ ./operators.jl:1044 [inlined]
[24] #_#103
@ ./operators.jl:1041 [inlined]
[25] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{…}, ChainRules.var"#721#724"{…}})
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
[26] #305
@ ~/.julia/packages/Zygote/55SqB/src/lib/lib.jl:214 [inlined]
[27] #2189#back
@ ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:72 [inlined]
[28] ComposedFunction
@ ./operators.jl:1041 [inlined]
[29] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{…}, ChainRules.var"#721#724"{…}})
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
[30] #get_initial_values#9
@ ~/.julia/packages/SciMLBase/YgvDS/src/initialization.jl:322 [inlined]
[31] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Tuple{ChainRulesCore.InplaceableThunk{…}, @NamedTuple{…}, Nothing})
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
[32] get_initial_values
@ ~/.julia/packages/SciMLBase/YgvDS/src/initialization.jl:267 [inlined]
[33] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Tuple{ChainRulesCore.InplaceableThunk{…}, @NamedTuple{…}, Nothing})
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
[34] #292
@ ~/.julia/packages/SciMLSensitivity/3s0Zy/src/concrete_solve.jl:587 [inlined]
[35] (::Zygote.Pullback{Tuple{SciMLSensitivity.var"#292#302"{GaussAdjoint{…}, Nothing, @NamedTuple{…}}, Vector{Float64}}, Any})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
[36] (::Zygote.var"#88#89"{Zygote.Pullback{Tuple{SciMLSensitivity.var"#292#302"{GaussAdjoint{…}, Nothing, @NamedTuple{…}}, Vector{Float64}}, Any}})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface.jl:97
[37] _concrete_solve_adjoint(::ODEProblem{…}, ::Tsit5{…}, ::GaussAdjoint{…}, ::Vector{…}, ::MTKParameters{…}, ::SciMLBase.ChainRulesOriginator; save_start::Bool, save_end::Bool, saveat::Vector{…}, save_idxs::Nothing, initializealg_default::SciMLBase.OverrideInit{…}, kwargs::@Kwargs{…})
@ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/3s0Zy/src/concrete_solve.jl:602
[38] _concrete_solve_adjoint
@ ~/.julia/packages/SciMLSensitivity/3s0Zy/src/concrete_solve.jl:465 [inlined]
[39] #_solve_adjoint#51
@ ~/.julia/packages/DiffEqBase/2gXT1/src/solve.jl:1045 [inlined]
[40] _solve_adjoint
@ ~/.julia/packages/DiffEqBase/2gXT1/src/solve.jl:1022 [inlined]
[41] #rrule#4
@ ~/.julia/packages/DiffEqBase/2gXT1/ext/DiffEqBaseChainRulesCoreExt.jl:32 [inlined]
[42] rrule
@ ~/.julia/packages/DiffEqBase/2gXT1/ext/DiffEqBaseChainRulesCoreExt.jl:26 [inlined]
[43] rrule
@ ~/.julia/packages/ChainRulesCore/Vsbj9/src/rules.jl:144 [inlined]
[44] chain_rrule_kw
@ ~/.julia/packages/Zygote/55SqB/src/compiler/chainrules.jl:246 [inlined]
[45] macro expansion
@ ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0 [inlined]
[46] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(DiffEqBase.solve_up), ::ODEProblem{…}, ::GaussAdjoint{…}, ::Vector{…}, ::MTKParameters{…}, ::Tsit5{…})
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:81
[47] _apply(::Function, ::Vararg{Any})
@ Core ./boot.jl:838
[48] adjoint
@ ~/.julia/packages/Zygote/55SqB/src/lib/lib.jl:211 [inlined]
[49] _pullback
@ ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:67 [inlined]
[50] #solve#37
@ ~/.julia/packages/DiffEqBase/2gXT1/src/solve.jl:587 [inlined]
[51] _pullback(::Zygote.Context{…}, ::DiffEqBase.var"##solve#37", ::GaussAdjoint{…}, ::Nothing, ::Nothing, ::Val{…}, ::@Kwargs{…}, ::typeof(solve), ::ODEProblem{…}, ::Tsit5{…})
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
[52] _apply(::Function, ::Vararg{Any})
@ Core ./boot.jl:838
[53] adjoint
@ ~/.julia/packages/Zygote/55SqB/src/lib/lib.jl:211 [inlined]
[54] _pullback
@ ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:67 [inlined]
[55] solve
@ ~/.julia/packages/DiffEqBase/2gXT1/src/solve.jl:575 [inlined]
[56] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(solve), ::ODEProblem{…}, ::Tsit5{…})
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
[57] a_final
@ ./REPL[8]:7 [inlined]
[58] _pullback(ctx::Zygote.Context{false}, f::typeof(a_final), args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
[59] pullback(f::Function, cx::Zygote.Context{false}, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface.jl:96
[60] pullback
@ ~/.julia/packages/Zygote/55SqB/src/compiler/interface.jl:94 [inlined]
[61] gradient(f::Function, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface.jl:153
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
bug: Something isn't working