Skip to content

Commit 1c615be

Browse files
Merge branch 'SciML:master' into dg/paramgrad
2 parents f020500 + 45612a9 commit 1c615be

File tree

14 files changed

+495
-35
lines changed

14 files changed

+495
-35
lines changed

.github/workflows/Downgrade.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ jobs:
1515
runs-on: ubuntu-latest
1616
strategy:
1717
matrix:
18-
version: ['1']
18+
version: ['min']
1919
steps:
2020
- uses: actions/checkout@v4
2121
- uses: julia-actions/setup-julia@v2

Project.toml

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SciMLBase"
22
uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
33
authors = ["Chris Rackauckas <[email protected]> and contributors"]
4-
version = "2.79.0"
4+
version = "2.82.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -60,11 +60,10 @@ ChainRules = "1.58.0"
6060
ChainRulesCore = "1.18"
6161
CommonSolve = "0.2.4"
6262
ConstructionBase = "1.5"
63-
DataFrames = "1.6"
6463
Distributed = "1.10"
6564
DocStringExtensions = "0.9"
6665
EnumX = "1"
67-
ForwardDiff = "0.10.36"
66+
ForwardDiff = "0.10.36, 1"
6867
FunctionWrappersWrappers = "0.1.3"
6968
IteratorInterfaceExtensions = "^1"
7069
LinearAlgebra = "1.10"
@@ -73,7 +72,6 @@ MLStyle = "0.4.17"
7372
Makie = "0.20, 0.21, 0.22"
7473
Markdown = "1.10"
7574
Moshi = "0.3"
76-
NonlinearSolve = "3, 4"
7775
PartialFunctions = "1.1"
7876
PrecompileTools = "1.2"
7977
Preferences = "1.3"
@@ -99,12 +97,8 @@ julia = "1.10"
9997
[extras]
10098
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
10199
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
102-
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
103-
DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb"
104100
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
105101
MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078"
106-
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
107-
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
108102
PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"
109103
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
110104
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
@@ -113,11 +107,10 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
113107
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
114108
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
115109
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
116-
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
117110
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
118111
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
119112
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
120113
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
121114

122115
[targets]
123-
test = ["Pkg", "Plots", "UnicodePlots", "SafeTestsets", "Serialization", "Test", "StableRNGs", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote", "PartialFunctions", "DataFrames", "NonlinearSolve", "OrdinaryDiffEq", "ForwardDiff", "Tables", "MLStyle"]
116+
test = ["Aqua", "ForwardDiff", "MLStyle", "PartialFunctions", "Pkg", "Plots", "SafeTestsets", "Serialization", "StableRNGs", "StaticArrays", "Tables", "Test", "UnicodePlots", "Zygote"]

ext/SciMLBaseZygoteExt.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,14 @@ function obs_grads(VA, sym, obs_idx, Δ)
116116
back(Δobs)
117117
end
118118

119+
function obs_grads2(VA::SciMLBase.NonlinearSolution, sym, obs_idx, Δ)
120+
y, back = Zygote.pullback(VA) do sol
121+
getindex.(Ref(sol), sym[obs_idx])
122+
end
123+
Δobs = Δ[obs_idx, :]
124+
back(Δobs)
125+
end
126+
119127
function obs_grads(VA, sym, ::Nothing, Δ)
120128
Zygote.nt_nothing(VA)
121129
end
@@ -154,6 +162,31 @@ end
154162
VA[sym], ODESolution_getindex_pullback
155163
end
156164

165+
@adjoint function Base.getindex(VA::SciMLBase.NonlinearSolution, sym)
166+
function NonlinearSolution_getindex_pullback(Δ)
167+
i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym
168+
if is_observed(VA, sym)
169+
f = observed(VA, sym)
170+
p = parameter_values(VA)
171+
u = state_values(VA)
172+
_, back = Zygote.pullback(u, p) do u, p
173+
f.f_oop(u, p)
174+
end
175+
gs = back(Δ)
176+
((u = gs[1], prob = (p = gs[2],),), nothing)
177+
elseif i === nothing
178+
throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated."))
179+
else
180+
VA = recursivecopy(VA)
181+
recursivefill!(VA, zero(eltype(VA)))
182+
v = view(VA, i, ntuple(_ -> :, ndims(VA) - 1)...)
183+
copyto!(v, Δ)
184+
(VA, nothing)
185+
end
186+
end
187+
VA[sym], NonlinearSolution_getindex_pullback
188+
end
189+
157190
@adjoint function ODESolution{
158191
T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15}(u,
159192
args...) where {T1, T2, T3, T4, T5, T6, T7, T8,

src/problems/rode_problems.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,6 @@ When a keyword argument is `nothing`, the default behaviour of the solver is use
110110
* `alias::Union{Bool, Nothing}`: sets all fields of the `RODEAliasSpecifier` to `alias`
111111
112112
"""
113-
114113
struct RODEAliasSpecifier <: AbstractAliasSpecifier
115114
alias_p::Union{Bool, Nothing}
116115
alias_f::Union{Bool, Nothing}

src/remake.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ function remake(
191191
props = @delete props._func_cache
192192
props = @insert props._func_cache = forig._func_cache
193193
end
194-
194+
195195
args = (args..., f2)
196196
end
197197
end

src/scimlfunctions.jl

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3067,7 +3067,7 @@ function ImplicitDiscreteFunction{iip, specialize}(f;
30673067
f.observed :
30683068
DEFAULT_OBSERVED,
30693069
resid_prototype = __has_resid_prototype(f) ?
3070-
f.resid_prototype :
3070+
f.resid_prototype :
30713071
nothing,
30723072
sys = __has_sys(f) ? f.sys : nothing,
30733073
initialization_data = __has_initialization_data(f) ? f.initialization_data :
@@ -3083,11 +3083,12 @@ function ImplicitDiscreteFunction{iip, specialize}(f;
30833083
analytic,
30843084
observed,
30853085
sys,
3086-
resid_prototype,
3086+
resid_prototype,
30873087
initialization_data)
30883088
else
30893089
ImplicitDiscreteFunction{
3090-
iip, specialize, typeof(_f), typeof(analytic), typeof(observed), typeof(sys), typeof(resid_prototype),
3090+
iip, specialize, typeof(_f), typeof(analytic),
3091+
typeof(observed), typeof(sys), typeof(resid_prototype),
30913092
typeof(initialization_data)}(
30923093
_f, analytic, observed, sys, resid_prototype, initialization_data)
30933094
end
@@ -3097,7 +3098,9 @@ function ImplicitDiscreteFunction{iip}(f; kwargs...) where {iip}
30973098
ImplicitDiscreteFunction{iip, FullSpecialize}(f; kwargs...)
30983099
end
30993100
ImplicitDiscreteFunction{iip}(f::ImplicitDiscreteFunction; kwargs...) where {iip} = f
3100-
function ImplicitDiscreteFunction(f; resid_prototype = __has_resid_prototype(f) ? f.resid_prototype : nothing, kwargs...)
3101+
function ImplicitDiscreteFunction(
3102+
f; resid_prototype = __has_resid_prototype(f) ? f.resid_prototype : nothing,
3103+
kwargs...)
31013104
ImplicitDiscreteFunction{isinplace(f, 5), FullSpecialize}(f; resid_prototype, kwargs...)
31023105
end
31033106
ImplicitDiscreteFunction(f::ImplicitDiscreteFunction; kwargs...) = f
@@ -3107,11 +3110,13 @@ function unwrapped_f(f::ImplicitDiscreteFunction, newf = unwrapped_f(f.f))
31073110

31083111
if specialize === NoSpecialize
31093112
ImplicitDiscreteFunction{isinplace(f, 5), specialize, Any, Any, Any,
3110-
Any, Any, Any}(newf, f.analytic, f.observed, f.sys, f.resid_prototype, f.initialization_data)
3113+
Any, Any, Any}(
3114+
newf, f.analytic, f.observed, f.sys, f.resid_prototype, f.initialization_data)
31113115
else
31123116
ImplicitDiscreteFunction{isinplace(f, 5), specialize, typeof(newf),
31133117
typeof(f.analytic),
3114-
typeof(f.observed), typeof(f.sys), typeof(resid_prototype), typeof(f.initialization_data)}(newf,
3118+
typeof(f.observed), typeof(f.sys), typeof(resid_prototype), typeof(f.initialization_data)}(
3119+
newf,
31153120
f.analytic, f.observed, f.sys, f.resid_prototype, f.initialization_data)
31163121
end
31173122
end

src/solutions/rode_solutions.jl

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,15 @@ https://docs.sciml.ai/DiffEqDocs/stable/basics/solution/
3232
exited due to an error. For more details, see
3333
[the return code documentation](https://docs.sciml.ai/SciMLBase/stable/interfaces/Solutions/#retcodes).
3434
"""
35-
struct RODESolution{T, N, uType, uType2, DType, tType, randType, P, A, IType, S,
35+
struct RODESolution{T, N, uType, uType2, DType, tType, randType, discType, P, A, IType, S,
3636
AC <: Union{Nothing, Vector{Int}}, V} <:
3737
AbstractRODESolution{T, N, uType}
3838
u::uType
3939
u_analytic::uType2
4040
errors::DType
4141
t::tType
4242
W::randType
43+
discretes::discType
4344
prob::P
4445
alg::A
4546
interp::IType
@@ -63,9 +64,10 @@ function ConstructionBase.setproperties(sol::RODESolution, patch::NamedTuple)
6364
patch = merge(getproperties(sol), patch)
6465
return RODESolution{
6566
T, N, typeof(patch.u), typeof(patch.u_analytic), typeof(patch.errors),
66-
typeof(patch.t), typeof(patch.W), typeof(patch.prob), typeof(patch.alg), typeof(patch.interp),
67+
typeof(patch.t), typeof(patch.W), typeof(patch.discretes),
68+
typeof(patch.prob), typeof(patch.alg), typeof(patch.interp),
6769
typeof(patch.stats), typeof(patch.alg_choice), typeof(patch.saved_subsystem)}(
68-
patch.u, patch.u_analytic, patch.errors, patch.t, patch.W,
70+
patch.u, patch.u_analytic, patch.errors, patch.t, patch.W, patch.discretes,
6971
patch.prob, patch.alg, patch.interp, patch.dense, patch.tslocation, patch.stats,
7072
patch.alg_choice, patch.retcode, patch.seed, patch.saved_subsystem)
7173
end
@@ -120,16 +122,28 @@ function build_solution(prob::Union{AbstractRODEProblem, AbstractSDDEProblem},
120122
Base.depwarn(msg, :build_solution)
121123
end
122124

125+
ps = parameter_values(prob)
126+
if has_sys(prob.f)
127+
sswf = if saved_subsystem === nothing
128+
prob.f.sys
129+
else
130+
SavedSubsystemWithFallback(saved_subsystem, prob.f.sys)
131+
end
132+
discretes = create_parameter_timeseries_collection(sswf, ps, prob.tspan)
133+
else
134+
discretes = nothing
135+
end
123136
if has_analytic(f)
124137
u_analytic = Vector{typeof(prob.u0)}()
125138
errors = Dict{Symbol, real(eltype(prob.u0))}()
126139
sol = RODESolution{T, N, typeof(u), typeof(u_analytic), typeof(errors), typeof(t),
127-
typeof(W),
140+
typeof(W), typeof(discretes),
128141
typeof(prob), typeof(alg), typeof(interp), typeof(stats),
129142
typeof(alg_choice), typeof(saved_subsystem)}(u,
130143
u_analytic,
131144
errors,
132145
t, W,
146+
discretes,
133147
prob,
134148
alg,
135149
interp,
@@ -149,15 +163,37 @@ function build_solution(prob::Union{AbstractRODEProblem, AbstractSDDEProblem},
149163
return sol
150164
else
151165
return RODESolution{T, N, typeof(u), Nothing, Nothing, typeof(t),
152-
typeof(W), typeof(prob), typeof(alg), typeof(interp),
166+
typeof(W), typeof(discretes), typeof(prob), typeof(alg), typeof(interp),
153167
typeof(stats), typeof(alg_choice), typeof(saved_subsystem)}(
154-
u, nothing, nothing, t, W,
168+
u, nothing, nothing, t, W, discretes,
155169
prob, alg, interp,
156170
dense, 0, stats,
157171
alg_choice, retcode, seed, saved_subsystem)
158172
end
159173
end
160174

175+
function save_discretes!(sol::AbstractRODESolution, t, vals, timeseries_idx)
176+
RecursiveArrayTools.has_discretes(sol) || return
177+
disc = RecursiveArrayTools.get_discretes(sol)
178+
_save_discretes_internal!(disc[timeseries_idx], t, vals)
179+
end
180+
181+
function get_interpolated_discretes(sol::AbstractRODESolution, t, deriv, continuity)
182+
is_parameter_timeseries(sol) == Timeseries() || return nothing
183+
184+
discs::ParameterTimeseriesCollection = RecursiveArrayTools.get_discretes(sol)
185+
interp_discs = map(discs) do partition
186+
hold_discrete(partition.u, partition.t, t)
187+
end
188+
return ParameterTimeseriesCollection(interp_discs, parameter_values(discs))
189+
end
190+
191+
function SymbolicIndexingInterface.is_parameter_timeseries(::Type{S}) where {
192+
T1, T2, T3, T4, T5, T6, T7,
193+
S <: RODESolution{T1, T2, T3, T4, T5, T6, T7, <:ParameterTimeseriesCollection}}
194+
Timeseries()
195+
end
196+
161197
function calculate_solution_errors!(sol::AbstractRODESolution; fill_uanalytic = true,
162198
timeseries_errors = true, dense_errors = true)
163199
if sol.prob.f isa Tuple

test/aqua.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,14 @@
11
using Test
22
using SciMLBase
33
using Aqua
4+
using Pkg
5+
6+
# yes this is horrible, we'll fix it when Pkg or Base provides a decent API
7+
manifest = Pkg.Types.EnvCache().manifest
8+
# these are good sentinels to test whether someone has added a heavy SciML package to the test deps
9+
if haskey(manifest.deps, "NonlinearSolveBase") || haskey(manifest.deps, "DiffEqBase")
10+
error("Don't put Downstream Packages in non Downstream CI")
11+
end
412

513
# https://github.com/JuliaArrays/FillArrays.jl/pull/163
614
@test isempty(detect_ambiguities(SciMLBase))

test/downstream/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
[deps]
22
BoundaryValueDiffEq = "764a87c0-6b3e-53db-9096-fe964310641d"
3+
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
34
DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb"
45
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
56
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
@@ -30,6 +31,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3031

3132
[compat]
3233
BoundaryValueDiffEq = "5"
34+
DataFrames = "1.6"
3335
DelayDiffEq = "5"
3436
DiffEqCallbacks = "3, 4"
3537
ForwardDiff = "0.10"

test/downstream/comprehensive_indexing.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -977,3 +977,19 @@ end
977977
sol = solve(prob, ImplicitEM())
978978
@test sol[sym] sol(sol.t .- sol.ps[delay]; idxs = original)
979979
end
980+
981+
@testset "RODESolutions save discretes" begin
982+
@parameters k(t)
983+
@variables A(t)
984+
function affect2!(integ, u, p, ctx)
985+
integ.ps[p.k] += 1.0
986+
end
987+
db = 1.0 => (affect2!, [], [k], [k], nothing)
988+
989+
@named ssys = SDESystem(D(A) ~ k * A, [0.0], t, [A], [k], discrete_events = db)
990+
ssys = complete(ssys)
991+
prob = SDEProblem(ssys, [A => 1.0], (0.0, 4.0), [k => 1.0])
992+
sol = solve(prob, RI5())
993+
@test sol[k] isa AbstractVector
994+
@test sol[k] == [1.0, 2.0, 3.0, 4.0]
995+
end

0 commit comments

Comments
 (0)