Description
Describe the bug 🐞
Sister to #69
The current plan is for @AayushSabharwal to do a major rewrite of `getp` so that it handles hybrid continuous and discrete systems. In the process, we will track the behaviour of `getp` to make sure it works correctly under automatic differentiation (AD), since there are known bugs in how it currently works.
This issue documents the current situation and tracks progress. We also need to add AD tests for `getp`/`getu` in SciMLBase and SymbolicIndexingInterface (SII) so that this behaviour is checked over time.
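As a starting point for those tests, here is a minimal, dependency-free sketch of the property they could check: the gradient of `sum` over the values returned by a multi-parameter getter should match a finite-difference estimate. `getvals` and `fd_gradient` are hypothetical stand-ins that index a plain vector rather than calling `getp`. The slot layout (ρ in slot 1, β in slot 2) is inferred from the `ParameterIndex` values printed in this issue; σ occupying slot 3 is an assumption. A real test would presumably call `Zygote.gradient` on `sol` as in the reproducer and read the `prob.p.tunable` field of the result instead.

```julia
# Hypothetical stand-in for a multi-parameter getter such as getp(sol, [β, β]):
# it returns the entries of a plain parameter vector at the given indices.
getvals(p::AbstractVector, idxs) = [p[i] for i in idxs]

# Central finite-difference gradient of a scalar function f at p.
function fd_gradient(f, p::AbstractVector{Float64}; h = 1e-6)
    g = similar(p)
    for j in eachindex(p)
        e = zeros(length(p))
        e[j] = h
        g[j] = (f(p .+ e) - f(p .- e)) / (2h)
    end
    return g
end

# Assumed tunable layout (ρ, β, σ) with the values from the reproducer.
p = [10.0, 8 / 3, 28.0]

g_repeat = fd_gradient(q -> sum(getvals(q, [2, 2])), p)  # analogue of getp(sol, [sys.β, sys.β])
g_mixed = fd_gradient(q -> sum(getvals(q, [2, 1])), p)   # analogue of getp(sol, [sys.β, sys.ρ])

println(round.(g_repeat; digits = 6))
println(round.(g_mixed; digits = 6))
```

Because the summed getter is linear in `p`, the finite-difference estimate is exact up to rounding, which makes it a convenient reference for an AD regression test.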
Expected behavior
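Gradients produced through `getp` should be correct, including when it is passed a vector of parameters (with or without repeated entries).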
Minimal Reproducible Example 👇
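using ModelingToolkit, OrdinaryDiffEq, SymbolicIndexingInterface, Zygote
# t and D are ModelingToolkit's built-in unitless independent variable and derivative
using ModelingToolkit: t_nounits as t, D_nounits as D
@parameters σ ρ β
@variables x(t) y(t) z(t) w(t)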
eqs = [D(D(x)) ~ σ * (y - x),
D(y) ~ x * (ρ - z) - y,
D(z) ~ x * y - β * z,
w ~ x + y + z + 2 * β,]
@mtkbuild sys = ODESystem(eqs, t)
ModelingToolkit.observed(sys)
u0 = [D(x) => 2.0,
x => 1.0,
y => 0.0,
z => 0.0]
p = [σ => 28.0,
ρ => 10.0,
β => 8 / 3,]
tspan = (0.0, 100.0)
prob = ODEProblem(sys, u0, tspan, p, jac = true)
sol = solve(prob, Tsit5())
Error & Stacktrace
julia> pf = getp(sol, sys.β)
SymbolicIndexingInterface.GetParameterIndex{ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Tuple{Int64, Int64}}}(ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Tuple{Int64, Int64}}(SciMLStructures.Tunable(), (1, 2)))
julia> gradient(sol) do sol # correct
sum(pf(sol))
end
((u = nothing, u_analytic = nothing, errors = nothing, t = nothing, k = nothing, prob = (f = nothing, u0 = nothing, tspan = nothing, p = (tunable = ([0.0, 1.0, 0.0],), discrete = nothing, constant = nothing, dependent = nothing, nonnumeric = nothing, dependent_update_iip = nothing, dependent_update_oop = nothing), kwargs = nothing, problem_type = nothing), alg = nothing, interp = nothing, dense = nothing, tslocation = nothing, stats = nothing, alg_choice = nothing, retcode = nothing, resid = nothing, original = nothing),)
julia> pf2 = getp(sol, [sys.β])
SymbolicIndexingInterface.MultipleParameterGetters{Vector{SymbolicIndexingInterface.GetParameterIndex{ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Tuple{Int64, Int64}}}}}(SymbolicIndexingInterface.GetParameterIndex{ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Tuple{Int64, Int64}}}[SymbolicIndexingInterface.GetParameterIndex{ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Tuple{Int64, Int64}}}(ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Tuple{Int64, Int64}}(SciMLStructures.Tunable(), (1, 2)))])
julia> gradient(sol) do sol # still correct, note single element in vector
sum(pf2(sol))
end
((u = nothing, u_analytic = nothing, errors = nothing, t = nothing, k = nothing, prob = (f = nothing, u0 = nothing, tspan = nothing, p = (tunable = ([0.0, 1.0, 0.0],), discrete = nothing, constant = nothing, dependent = nothing, nonnumeric = nothing, dependent_update_iip = nothing, dependent_update_oop = nothing), kwargs = nothing, problem_type = nothing), alg = nothing, interp = nothing, dense = nothing, tslocation = nothing, stats = nothing, alg_choice = nothing, retcode = nothing, resid = nothing, original = nothing),)
julia> pf3 = getp(sol, [sys.β, sys.β])
SymbolicIndexingInterface.MultipleParameterGetters{Vector{SymbolicIndexingInterface.GetParameterIndex{ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Tuple{Int64, Int64}}}}}(SymbolicIndexingInterface.GetParameterIndex{ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Tuple{Int64, Int64}}}[SymbolicIndexingInterface.GetParameterIndex{ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Tuple{Int64, Int64}}}(ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Tuple{Int64, Int64}}(SciMLStructures.Tunable(), (1, 2))), SymbolicIndexingInterface.GetParameterIndex{ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Tuple{Int64, Int64}}}(ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Tuple{Int64, Int64}}(SciMLStructures.Tunable(), (1, 2)))])
julia> gradient(sol) do sol # incorrect, should be [0.0, 2.0, 0.0]
sum(pf3(sol))
end
((u = nothing, u_analytic = nothing, errors = nothing, t = nothing, k = nothing, prob = (f = nothing, u0 = nothing, tspan = nothing, p = (tunable = ([0.0, 3.0, 0.0],), discrete = nothing, constant = nothing, dependent = nothing, nonnumeric = nothing, dependent_update_iip = nothing, dependent_update_oop = nothing), kwargs = nothing, problem_type = nothing), alg = nothing, interp = nothing, dense = nothing, tslocation = nothing, stats = nothing, alg_choice = nothing, retcode = nothing, resid = nothing, original = nothing),)
julia> pf4 = getp(sol, [sys.β, sys.ρ])
SymbolicIndexingInterface.MultipleParameterGetters{Vector{SymbolicIndexingInterface.GetParameterIndex{ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Tuple{Int64, Int64}}}}}(SymbolicIndexingInterface.GetParameterIndex{ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Tuple{Int64, Int64}}}[SymbolicIndexingInterface.GetParameterIndex{ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Tuple{Int64, Int64}}}(ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Tuple{Int64, Int64}}(SciMLStructures.Tunable(), (1, 2))), SymbolicIndexingInterface.GetParameterIndex{ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Tuple{Int64, Int64}}}(ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Tuple{Int64, Int64}}(SciMLStructures.Tunable(), (1, 1)))])
julia> gradient(sol) do sol # incorrect, should be [1.0, 1.0, 0.0]
sum(pf4(sol))
end
((u = nothing, u_analytic = nothing, errors = nothing, t = nothing, k = nothing, prob = (f = nothing, u0 = nothing, tspan = nothing, p = (tunable = ([2.0, 1.0, 0.0],), discrete = nothing, constant = nothing, dependent = nothing, nonnumeric = nothing, dependent_update_iip = nothing, dependent_update_oop = nothing), kwargs = nothing, problem_type = nothing), alg = nothing, interp = nothing, dense = nothing, tslocation = nothing, stats = nothing, alg_choice = nothing, retcode = nothing, resid = nothing, original = nothing),)
The gradients returned when `getp` is passed a vector containing more than one parameter are incorrect; a single-element vector gives the correct result.
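For reference, a correct reverse-mode rule for a multi-parameter getter should accumulate one cotangent per requested entry into the slot it was read from, so a repeated parameter adds up rather than being overwritten or double-counted. With `sum`, every returned value has cotangent 1, so the expected gradient with respect to the tunable vector is just a count of how often each slot is requested. The sketch below computes that reference with a hypothetical helper `expected_sum_gradient`, using the slots inferred from the outputs above (ρ in slot 1, β in slot 2), and compares it with the observed gradients copied from this issue. In both multi-parameter cases the observed result exceeds the reference by exactly 1.0 in the slot of the last parameter in the list, while the single-element case matches.

```julia
# Hypothetical reference: gradient of sum(p[idxs]) with respect to p.
# Each requested index contributes a cotangent of 1 to its own slot,
# so an index requested twice contributes 2.
function expected_sum_gradient(n::Integer, idxs)
    g = zeros(n)
    for i in idxs
        g[i] += 1.0  # accumulate, never overwrite
    end
    return g
end

# Slots inferred from the ParameterIndex values in this issue: ρ => 1, β => 2.
# The "observed" vectors are the tunable gradients reported above.
cases = [
    ("[β]", [2], [0.0, 1.0, 0.0]),
    ("[β, β]", [2, 2], [0.0, 3.0, 0.0]),
    ("[β, ρ]", [2, 1], [2.0, 1.0, 0.0]),
]

for (label, idxs, observed) in cases
    expected = expected_sum_gradient(3, idxs)
    println(rpad(label, 8), " expected ", expected,
            "  observed ", observed, "  excess ", observed .- expected)
end
```

The printed `excess` column makes the pattern easy to spot and could double as the assertion in a regression test once `getp` is fixed.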
Environment (please complete the following information):
- Output of
using Pkg; Pkg.status()
(SciMLSensitivity) pkg> st
Project SciMLSensitivity v7.57.0
Status `~/arpa/jsmo/clone/SciMLSensitivity.jl/Project.toml`
⌅ [47edcb42] ADTypes v0.2.7
[79e6a3ab] Adapt v4.0.4
[4fba245c] ArrayInterface v7.10.0
[082447d4] ChainRules v1.66.0 `https://github.com/JuliaDiff/ChainRules.jl.git#main`
[d360d2e6] ChainRulesCore v1.23.0
[2b5f629d] DiffEqBase v6.149.1 `https://github.com/DhairyaLGandhi/DiffEqBase.jl.git#dg/kw`
[459566f4] DiffEqCallbacks v3.6.2
[77a26b50] DiffEqNoiseProcess v5.21.0
[31c24e10] Distributions v0.25.108
[da5c29d0] EllipsisNotation v1.8.0
[7da242da] Enzyme v0.12.6
[6a86dc24] FiniteDiff v2.23.1
[f6369f11] ForwardDiff v0.10.36
[f62d2435] FunctionProperties v0.1.2
[77dc65aa] FunctionWrappersWrappers v0.1.3
[d9f16b24] Functors v0.4.10
[46192b85] GPUArraysCore v0.1.6
[7ed4a6bd] LinearSolve v2.30.0
[961ee093] ModelingToolkit v9.13.0
[1dea7af3] OrdinaryDiffEq v6.76.0
[d96e819e] Parameters v0.12.3
[d236fae5] PreallocationTools v0.4.21
[1fd47b50] QuadGK v2.9.4
[e6cf234a] RandomNumbers v1.5.3
[731186ca] RecursiveArrayTools v3.18.1 `../RecursiveArrayTools.jl`
[189a3867] Reexport v1.2.2
[37e2e3b7] ReverseDiff v1.15.3
[0bca4576] SciMLBase v2.36.1 `https://github.com/DhairyaLGandhi/SciMLBase.jl#dg/obsfn`
[c0aeaf25] SciMLOperators v0.3.8
[53ae85a6] SciMLStructures v1.2.0
⌃ [47a9eef4] SparseDiffTools v2.18.0
[90137ffa] StaticArrays v1.9.3
[1e83bf80] StaticArraysCore v1.4.2
[789caeaf] StochasticDiffEq v6.65.1
[2efcf032] SymbolicIndexingInterface v0.3.21
[9f7883ad] Tracker v0.2.34
[781d530d] TruncatedStacktraces v1.4.0
[e88e6eb3] Zygote v0.6.70
[37e2e46d] LinearAlgebra
[d6f4376e] Markdown
[9a3f8284] Random
[10745b16] Statistics v1.10.0
Info Packages marked with ⌃ and ⌅ have new versions available. Those with ⌃ may be upgradable, but those with ⌅ are restricted by compatibility constraints from upgrading. To see why use `status --outdated`
- Output of
using Pkg; Pkg.status(; mode = PKGMODE_MANIFEST)
- Output of
versioninfo()
Julia Version 1.10.0
Commit 3120989f39b (2023-12-25 18:01 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 64 × AMD EPYC 7513 32-Core Processor
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-15.0.7 (ORCJIT, znver3)
Threads: 1 on 64 virtual cores