Closed
Description
This is a case from the Catalyst docs that worked fine until earlier this week, so it is probably caused by a recent update. I think I have tracked it down to automatic differentiation (AD), so it may belong here, but I'm not sure.
When I try to fit the parameters of this ODE with an AD-based optimizer, I get an error in the `solve` call. Here is a condensed version of the example:
# Load packages.
using Catalyst
using OrdinaryDiffEqDefault
using Optimization
using OptimizationOptimisers
using SciMLSensitivity
using SymbolicIndexingInterface
# Create model.
rs = @reaction_network begin
    (p,d), 0 <--> X
end
# Create `ODEProblem`.
u0 = [:X => 1.0]
p_true = [:p => 1.0, :d => 0.2]
oprob_true = ODEProblem(rs, u0, 10.0, p_true)
# Simulate training data.
sol_true = solve(oprob_true)
t_measured = 1:10
X_measured = sol_true(t_measured; idxs = :X)
# Add up to ±10% multiplicative noise to the measurements.
X_measured = [(0.9 + 0.2 * rand()) * x for x in X_measured]
# Create cost function.
set_p = SymbolicIndexingInterface.setp_oop(oprob_true, [:p, :d])
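function loss(p, (set_p, prob, t_measured, X_measured))
    # Write the raw parameter vector [p, d] into a new parameter object for `prob`.
    p = set_p(prob, p)
    newprob = remake(prob; p)
    sol = solve(newprob; verbose = false, maxiters = 10000, saveat = t_measured)
    X_sim = Array(sol)[1,:]
    # Sum of squared errors between simulation and measurements.
    return sum(abs2, X_sim .- X_measured)
end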
# Attempt to fit parameters.
of = OptimizationFunction(loss, Optimization.AutoZygote())
pinit = [0.5, 0.5]
optprob = OptimizationProblem(of, pinit, (set_p, oprob_true, t_measured, X_measured))
sol = solve(optprob, ADAM(0.1); maxiters = 100)

This fails with:

ERROR: MethodError: no method matching +(::@NamedTuple{…}, ::Base.RefValue{…})
The function `+` exists, but no method is defined for this combination of argument types.
Closest candidates are:
+(::Any, ::Any, ::Any, ::Any...)
@ Base operators.jl:596
+(::ChainRulesCore.NoTangent, ::Any)
@ ChainRulesCore ~/.julia/packages/ChainRulesCore/U6wNx/src/tangent_arithmetic.jl:59
+(::Any, ::ChainRulesCore.NoTangent)
@ ChainRulesCore ~/.julia/packages/ChainRulesCore/U6wNx/src/tangent_arithmetic.jl:60
...
Stacktrace:
[1] accum(x::@NamedTuple{…}, y::Base.RefValue{…})
@ Zygote ~/.julia/packages/Zygote/wfLOG/src/lib/lib.jl:9
[2] #get_initial_values#9
@ ~/.julia/packages/SciMLBase/iHgIu/src/initialization.jl:298 [inlined]
[3] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Tuple{ChainRulesCore.Thunk{…}, @NamedTuple{…}, Nothing})
@ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface2.jl:0
[4] get_initial_values
@ ~/.julia/packages/SciMLBase/iHgIu/src/initialization.jl:242 [inlined]
[5] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{ChainRulesCore.Thunk{…}, @NamedTuple{…}, Nothing})
@ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface2.jl:0
[6] maybe_eager_initialize_problem
@ ~/.julia/packages/SciMLBase/iHgIu/src/remake.jl:1211 [inlined]
[7] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Tuple{ChainRulesCore.Thunk{…}, @NamedTuple{…}})
@ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface2.jl:0
[8] #remake#771
@ ~/.julia/packages/SciMLBase/iHgIu/src/remake.jl:266 [inlined]
[9] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Base.RefValue{Any})
@ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface2.jl:0
[10] remake
@ ~/.julia/packages/SciMLBase/iHgIu/src/remake.jl:214 [inlined]
[11] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Base.RefValue{Any})
@ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface2.jl:0
[12] loss
@ ./Untitled-1:29 [inlined]
[13] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface2.jl:0
[14] (::Zygote.var"#88#89"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface.jl:97
[15] withgradient(::Function, ::Vector{Float64}, ::Vararg{Any})
@ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface.jl:219
[16] value_and_gradient
@ ~/.julia/packages/DifferentiationInterface/zJHX8/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl:115 [inlined]
[17] value_and_gradient!(f::Function, grad::Vector{…}, prep::DifferentiationInterface.NoGradientPrep{…}, backend::AutoZygote, x::Vector{…}, contexts::DifferentiationInterface.Constant{…})
@ DifferentiationInterfaceZygoteExt ~/.julia/packages/DifferentiationInterface/zJHX8/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl:131
[18] (::OptimizationZygoteExt.var"#fg!#16"{…})(res::Vector{…}, θ::Vector{…})
@ OptimizationZygoteExt ~/.julia/packages/OptimizationBase/UXLhR/ext/OptimizationZygoteExt.jl:53
[19] macro expansion
@ ~/.julia/packages/OptimizationOptimisers/xC7Ic/src/OptimizationOptimisers.jl:101 [inlined]
[20] macro expansion
@ ~/.julia/packages/Optimization/e1Lg1/src/utils.jl:32 [inlined]
[21] __solve(cache::OptimizationCache{…})
@ OptimizationOptimisers ~/.julia/packages/OptimizationOptimisers/xC7Ic/src/OptimizationOptimisers.jl:83
[22] solve!(cache::OptimizationCache{…})
@ SciMLBase ~/.julia/packages/SciMLBase/iHgIu/src/solve.jl:227
[23] solve(::OptimizationProblem{…}, ::Adam{…}; kwargs::@Kwargs{…})
@ SciMLBase ~/.julia/packages/SciMLBase/iHgIu/src/solve.jl:129
[24] top-level scope
@ Untitled-1:39
Some type information was truncated. Use `show(err)` to see complete types.
The same error occurs if I use BFGS instead:
using OptimizationOptimJL
sol = solve(optprob, OptimizationOptimJL.BFGS(); maxiters = 100)

However, with a derivative-free method, everything works fine:
using OptimizationNLopt
# No AD backend is given, so no gradients are computed.
of = OptimizationFunction(loss)
pinit = [0.5, 0.5]
optprob = OptimizationProblem(of, pinit, (set_p, oprob_true, t_measured, X_measured))
sol = solve(optprob, NLopt.LN_NELDERMEAD())