
Make getp AD friendly #77

@DhairyaLGandhi

Description


Describe the bug 🐞

Sister to #69

The current plan is for @AayushSabharwal to do a major rewrite of getp so that it handles hybrid continuous and discrete systems. In the process, we will track how getp behaves under AD to make sure it differentiates correctly, since there are known bugs in how it currently works.

This issue documents the current situation and tracks progress. We also need to add AD tests for getp/getu in SciMLBase and SymbolicIndexingInterface (SII) so that this behaviour is tracked over time.
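As one possible starting point for such tests, here is a minimal sketch of a finite-difference reference check that uses only the Test standard library. fd_gradient and check_gradient are hypothetical helper names, not part of SciMLBase or SII. In a real test, grad_f would wrap the gradient returned by an AD backend such as Zygote for a function that calls a getp getter; here it is a stand-in closure. The vector p assumes the tunable layout [ρ, β, σ] with the values from the example below; σ's position is an assumption.

```julia
using Test

# Central finite-difference gradient of a scalar function `f` at the vector `p`.
function fd_gradient(f, p::AbstractVector{<:Real}; h = 1e-6)
    T = float(eltype(p))
    g = zeros(T, length(p))
    for i in eachindex(p)
        e = zeros(T, length(p))
        e[i] = h
        g[i] = (f(p .+ e) - f(p .- e)) / (2h)
    end
    return g
end

# True when the candidate gradient `grad_f(p)` matches the finite-difference reference.
check_gradient(f, grad_f, p; atol = 1e-6) =
    isapprox(grad_f(p), fd_gradient(f, p); atol = atol)

# Assumed tunable layout [ρ, β, σ]; values taken from the reproducer.
p = [10.0, 8 / 3, 28.0]

@testset "sum over a repeated parameter getter" begin
    f(q) = sum(q[[2, 2]])  # stands in for sum(getp(sol, [sys.β, sys.β])(sol))
    @test check_gradient(f, _ -> [0.0, 2.0, 0.0], p)   # correct gradient passes
    @test !check_gradient(f, _ -> [0.0, 3.0, 0.0], p)  # value reported below fails
end
```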

Expected behavior

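Gradients produced by differentiating through getp should be correct, including when getp is passed a vector of parameters (with repeated or distinct entries).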
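Because each getter here only reads one entry of the tunable buffer, the correct gradient of sum(pf(sol)) with respect to that buffer is simply how many times each index is read. A minimal sketch, assuming the positions printed in the session below (ρ at ParameterIndex (1, 1), β at (1, 2), in a three-element buffer); expected_sum_gradient is a hypothetical helper, not a library function.

```julia
# Expected gradient of sum(p[idxs]) with respect to p: how often each index is read.
function expected_sum_gradient(n::Integer, idxs::AbstractVector{<:Integer})
    g = zeros(n)
    for i in idxs
        g[i] += 1.0
    end
    return g
end

# Assumed tunable positions, read off the printed ParameterIndex values.
const RHO_IDX = 1   # ρ: (1, 1)
const BETA_IDX = 2  # β: (1, 2)

expected_sum_gradient(3, [BETA_IDX])            # → [0.0, 1.0, 0.0], getp(sol, [sys.β])
expected_sum_gradient(3, [BETA_IDX, BETA_IDX])  # → [0.0, 2.0, 0.0], getp(sol, [sys.β, sys.β])
expected_sum_gradient(3, [BETA_IDX, RHO_IDX])   # → [1.0, 1.0, 0.0], getp(sol, [sys.β, sys.ρ])
```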

Minimal Reproducible Example 👇

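using ModelingToolkit, OrdinaryDiffEq, SymbolicIndexingInterface, Zygote
using ModelingToolkit: t_nounits as t, D_nounits as D

@parameters σ ρ β A2[1:10, 1:10]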
@variables x(t) y(t) z(t) w(t) w2(t)
# @variables A[1:10, 1:10]

eqs = [D(D(x)) ~ σ * (y - x),
    D(y) ~ x * (ρ - z) - y,
    D(z) ~ x * y - β * z,
    w ~ x + y + z + 2 * β,]

@mtkbuild sys = ODESystem(eqs, t)

ModelingToolkit.observed(sys)

u0 = [D(x) => 2.0,
    x => 1.0,
    y => 0.0,
    z => 0.0]

p = [σ => 28.0,
    ρ => 10.0,
    β => 8 / 3,]

tspan = (0.0, 100.0)
prob = ODEProblem(sys, u0, tspan, p, jac = true)
sol = solve(prob, Tsit5())

Error & Stacktrace ⚠️

julia> pf = getp(sol, sys.β)
SymbolicIndexingInterface.GetParameterIndex{ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Tuple{Int64, Int64}}}(ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Tuple{Int64, Int64}}(SciMLStructures.Tunable(), (1, 2)))

julia> gradient(sol) do sol # correct
         sum(pf(sol))
       end
((u = nothing, u_analytic = nothing, errors = nothing, t = nothing, k = nothing, prob = (f = nothing, u0 = nothing, tspan = nothing, p = (tunable = ([0.0, 1.0, 0.0],), discrete = nothing, constant = nothing, dependent = nothing, nonnumeric = nothing, dependent_update_iip = nothing, dependent_update_oop = nothing), kwargs = nothing, problem_type = nothing), alg = nothing, interp = nothing, dense = nothing, tslocation = nothing, stats = nothing, alg_choice = nothing, retcode = nothing, resid = nothing, original = nothing),)

julia> pf2 = getp(sol, [sys.β]) 
SymbolicIndexingInterface.MultipleParameterGetters{Vector{SymbolicIndexingInterface.GetParameterIndex{ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Tuple{Int64, Int64}}}}}(SymbolicIndexingInterface.GetParameterIndex{ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Tuple{Int64, Int64}}}[SymbolicIndexingInterface.GetParameterIndex{ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Tuple{Int64, Int64}}}(ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Tuple{Int64, Int64}}(SciMLStructures.Tunable(), (1, 2)))])

julia> gradient(sol) do sol # still correct, note single element in vector
         sum(pf2(sol))
       end
((u = nothing, u_analytic = nothing, errors = nothing, t = nothing, k = nothing, prob = (f = nothing, u0 = nothing, tspan = nothing, p = (tunable = ([0.0, 1.0, 0.0],), discrete = nothing, constant = nothing, dependent = nothing, nonnumeric = nothing, dependent_update_iip = nothing, dependent_update_oop = nothing), kwargs = nothing, problem_type = nothing), alg = nothing, interp = nothing, dense = nothing, tslocation = nothing, stats = nothing, alg_choice = nothing, retcode = nothing, resid = nothing, original = nothing),)

julia> pf3 = getp(sol, [sys.β, sys.β])
SymbolicIndexingInterface.MultipleParameterGetters{Vector{SymbolicIndexingInterface.GetParameterIndex{ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Tuple{Int64, Int64}}}}}(SymbolicIndexingInterface.GetParameterIndex{ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Tuple{Int64, Int64}}}[SymbolicIndexingInterface.GetParameterIndex{ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Tuple{Int64, Int64}}}(ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Tuple{Int64, Int64}}(SciMLStructures.Tunable(), (1, 2))), SymbolicIndexingInterface.GetParameterIndex{ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Tuple{Int64, Int64}}}(ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Tuple{Int64, Int64}}(SciMLStructures.Tunable(), (1, 2)))])

julia> gradient(sol) do sol # incorrect, should be [0.0, 2.0, 0.0]
         sum(pf3(sol))
       end
((u = nothing, u_analytic = nothing, errors = nothing, t = nothing, k = nothing, prob = (f = nothing, u0 = nothing, tspan = nothing, p = (tunable = ([0.0, 3.0, 0.0],), discrete = nothing, constant = nothing, dependent = nothing, nonnumeric = nothing, dependent_update_iip = nothing, dependent_update_oop = nothing), kwargs = nothing, problem_type = nothing), alg = nothing, interp = nothing, dense = nothing, tslocation = nothing, stats = nothing, alg_choice = nothing, retcode = nothing, resid = nothing, original = nothing),)

julia> pf4 = getp(sol, [sys.β, sys.ρ])
SymbolicIndexingInterface.MultipleParameterGetters{Vector{SymbolicIndexingInterface.GetParameterIndex{ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Tuple{Int64, Int64}}}}}(SymbolicIndexingInterface.GetParameterIndex{ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Tuple{Int64, Int64}}}[SymbolicIndexingInterface.GetParameterIndex{ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Tuple{Int64, Int64}}}(ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Tuple{Int64, Int64}}(SciMLStructures.Tunable(), (1, 2))), SymbolicIndexingInterface.GetParameterIndex{ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Tuple{Int64, Int64}}}(ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Tuple{Int64, Int64}}(SciMLStructures.Tunable(), (1, 1)))])

julia> gradient(sol) do sol # incorrect, should be [1.0, 1.0, 0.0]
         sum(pf4(sol))
       end
((u = nothing, u_analytic = nothing, errors = nothing, t = nothing, k = nothing, prob = (f = nothing, u0 = nothing, tspan = nothing, p = (tunable = ([2.0, 1.0, 0.0],), discrete = nothing, constant = nothing, dependent = nothing, nonnumeric = nothing, dependent_update_iip = nothing, dependent_update_oop = nothing), kwargs = nothing, problem_type = nothing), alg = nothing, interp = nothing, dense = nothing, tslocation = nothing, stats = nothing, alg_choice = nothing, retcode = nothing, resid = nothing, original = nothing),)

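The gradients are incorrect when getp is passed a vector containing more than one parameter: for [sys.β, sys.β] the β entry is 3.0 instead of 2.0, and for [sys.β, sys.ρ] the result is [2.0, 1.0, 0.0] instead of [1.0, 1.0, 0.0]. A single parameter, or a one-element vector, gives the correct gradient.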
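For comparison, here is a minimal sketch of a reverse-mode pullback for reading several entries out of a flat parameter vector. It allocates one fresh zero buffer per call and adds each getter's cotangent into it, so repeated indices accumulate and no contribution is counted twice. It assumes the tunable values live in a plain Vector{Float64} laid out as [ρ, β, σ]; gather and gather_pullback are hypothetical names. This is an illustration of the expected semantics, not SymbolicIndexingInterface's implementation and not a diagnosis of where the current bug comes from.

```julia
# Forward pass: read the entries of `p` at `idxs`, one value per getter.
gather(p::AbstractVector, idxs::AbstractVector{<:Integer}) = [p[i] for i in idxs]

# Pullback: given the cotangent `Δ` of the gathered values, return the cotangent of `p`.
function gather_pullback(p::AbstractVector, idxs::AbstractVector{<:Integer}, Δ::AbstractVector)
    length(Δ) == length(idxs) ||
        throw(DimensionMismatch("expected one cotangent per index"))
    grad = zeros(float(eltype(p)), length(p))  # fresh buffer, not shared between calls
    for (k, i) in enumerate(idxs)
        grad[i] += Δ[k]  # accumulate, so repeated indices add up
    end
    return grad
end

p = [10.0, 8 / 3, 28.0]  # assumed layout [ρ, β, σ]

# The gradient of sum(gather(p, idxs)) uses a cotangent of ones for the gathered values.
gather_pullback(p, [2, 2], ones(2))  # → [0.0, 2.0, 0.0]
gather_pullback(p, [2, 1], ones(2))  # → [1.0, 1.0, 0.0]
```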

Environment

  • Output of using Pkg; Pkg.status()
(SciMLSensitivity) pkg> st
Project SciMLSensitivity v7.57.0
Status `~/arpa/jsmo/clone/SciMLSensitivity.jl/Project.toml`
⌅ [47edcb42] ADTypes v0.2.7
  [79e6a3ab] Adapt v4.0.4
  [4fba245c] ArrayInterface v7.10.0
  [082447d4] ChainRules v1.66.0 `https://github.com/JuliaDiff/ChainRules.jl.git#main`
  [d360d2e6] ChainRulesCore v1.23.0
  [2b5f629d] DiffEqBase v6.149.1 `https://github.com/DhairyaLGandhi/DiffEqBase.jl.git#dg/kw`
  [459566f4] DiffEqCallbacks v3.6.2
  [77a26b50] DiffEqNoiseProcess v5.21.0
  [31c24e10] Distributions v0.25.108
  [da5c29d0] EllipsisNotation v1.8.0
  [7da242da] Enzyme v0.12.6
  [6a86dc24] FiniteDiff v2.23.1
  [f6369f11] ForwardDiff v0.10.36
  [f62d2435] FunctionProperties v0.1.2
  [77dc65aa] FunctionWrappersWrappers v0.1.3
  [d9f16b24] Functors v0.4.10
  [46192b85] GPUArraysCore v0.1.6
  [7ed4a6bd] LinearSolve v2.30.0
  [961ee093] ModelingToolkit v9.13.0
  [1dea7af3] OrdinaryDiffEq v6.76.0
  [d96e819e] Parameters v0.12.3
  [d236fae5] PreallocationTools v0.4.21
  [1fd47b50] QuadGK v2.9.4
  [e6cf234a] RandomNumbers v1.5.3
  [731186ca] RecursiveArrayTools v3.18.1 `../RecursiveArrayTools.jl`
  [189a3867] Reexport v1.2.2
  [37e2e3b7] ReverseDiff v1.15.3
  [0bca4576] SciMLBase v2.36.1 `https://github.com/DhairyaLGandhi/SciMLBase.jl#dg/obsfn`
  [c0aeaf25] SciMLOperators v0.3.8
  [53ae85a6] SciMLStructures v1.2.0
⌃ [47a9eef4] SparseDiffTools v2.18.0
  [90137ffa] StaticArrays v1.9.3
  [1e83bf80] StaticArraysCore v1.4.2
  [789caeaf] StochasticDiffEq v6.65.1
  [2efcf032] SymbolicIndexingInterface v0.3.21
  [9f7883ad] Tracker v0.2.34
  [781d530d] TruncatedStacktraces v1.4.0
  [e88e6eb3] Zygote v0.6.70
  [37e2e46d] LinearAlgebra
  [d6f4376e] Markdown
  [9a3f8284] Random
  [10745b16] Statistics v1.10.0
Info Packages marked with ⌃ and ⌅ have new versions available. Those with ⌃ may be upgradable, but those with ⌅ are restricted by compatibility constraints from upgrading. To see why use `status --outdated`
  • Output of versioninfo()
Julia Version 1.10.0
Commit 3120989f39b (2023-12-25 18:01 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 64 × AMD EPYC 7513 32-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, znver3)
  Threads: 1 on 64 virtual cores


Metadata

Assignees: No one assigned
Labels: bug (Something isn't working)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
