Skip to content

Commit 23f6936

Browse files
Merge pull request #943 from AayushSabharwal/as/fix-tests
fix: fix remake autodiff tests and Zygote adjoint
2 parents 6d1c9e6 + 5e2e25e commit 23f6936

15 files changed

+117
-215
lines changed

.github/workflows/Downstream.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ jobs:
4949
- {user: SciML, repo: SciMLSensitivity.jl, group: Core3}
5050
- {user: SciML, repo: SciMLSensitivity.jl, group: Core4}
5151
- {user: SciML, repo: SciMLSensitivity.jl, group: Core5}
52+
- {user: SciML, repo: SciMLSensitivity.jl, group: Core6}
53+
- {user: SciML, repo: SciMLSensitivity.jl, group: Core7}
5254
- {user: SciML, repo: Catalyst.jl, group: All}
5355

5456
steps:

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ SciMLBasePartialFunctionsExt = "PartialFunctions"
5050
SciMLBasePyCallExt = "PyCall"
5151
SciMLBasePythonCallExt = "PythonCall"
5252
SciMLBaseRCallExt = "RCall"
53-
SciMLBaseZygoteExt = "Zygote"
53+
SciMLBaseZygoteExt = ["Zygote", "ChainRulesCore"]
5454

5555
[compat]
5656
ADTypes = "0.2.5,1.0.0"

ext/SciMLBaseChainRulesCoreExt.jl

Lines changed: 16 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
module SciMLBaseChainRulesCoreExt
22

33
using SciMLBase
4+
using SciMLBase: getobserved
45
import ChainRulesCore
5-
import ChainRulesCore: NoTangent, @non_differentiable
6+
import ChainRulesCore: NoTangent, @non_differentiable, zero_tangent, rrule_via_ad
67
using SymbolicIndexingInterface
78

89
function ChainRulesCore.rrule(
@@ -15,52 +16,28 @@ function ChainRulesCore.rrule(
1516
j::Integer)
1617
function ODESolution_getindex_pullback(Δ)
1718
i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym
18-
if i === nothing
19+
du, dprob = if i === nothing
1920
getter = getobserved(VA)
2021
grz = rrule_via_ad(config, getter, sym, VA.u[j], VA.prob.p, VA.t[j])[2](Δ)
21-
du = [k == j ? grz[2] : zero(VA.u[1]) for k in 1:length(VA.u)]
22-
dp = grz[3] # pullback for p
22+
du = [k == j ? grz[3] : zero(VA.u[1]) for k in 1:length(VA.u)]
23+
dp = grz[4] # pullback for p
24+
if dp == NoTangent()
25+
dp = zero_tangent(parameter_values(VA.prob))
26+
end
2327
dprob = remake(VA.prob, p = dp)
24-
T = eltype(eltype(VA.u))
25-
N = length(VA.prob.p)
26-
Δ′ = ODESolution{T, N, typeof(du), Nothing, Nothing, Nothing, Nothing,
27-
typeof(dprob), Nothing, Nothing, Nothing, Nothing}(du, nothing,
28-
nothing, nothing, nothing, dprob, nothing, nothing,
29-
VA.dense, 0, nothing, nothing, VA.retcode)
30-
(NoTangent(), Δ′, NoTangent(), NoTangent())
28+
du, dprob
3129
else
3230
du = [m == j ? [i == k ? Δ : zero(VA.u[1][1]) for k in 1:length(VA.u[1])] :
3331
zero(VA.u[1]) for m in 1:length(VA.u)]
34-
dp = zero(VA.prob.p)
32+
dp = zero_tangent(VA.prob.p)
3533
dprob = remake(VA.prob, p = dp)
36-
Δ′ = ODESolution{
37-
T,
38-
N,
39-
typeof(du),
40-
Nothing,
41-
Nothing,
42-
typeof(VA.t),
43-
typeof(VA.k),
44-
typeof(dprob),
45-
typeof(VA.alg),
46-
typeof(VA.interp),
47-
typeof(VA.alg_choice),
48-
typeof(VA.stats)
49-
}(du,
50-
nothing,
51-
nothing,
52-
VA.t,
53-
VA.k,
54-
dprob,
55-
VA.alg,
56-
VA.interp,
57-
VA.dense,
58-
0,
59-
VA.stats,
60-
VA.alg_choice,
61-
VA.retcode)
62-
(NoTangent(), Δ′, NoTangent(), NoTangent())
34+
du, dprob
6335
end
36+
T = eltype(eltype(du))
37+
N = ndims(eltype(du)) + 1
38+
Δ′ = ODESolution{T, N}(du, nothing, nothing, VA.t, VA.k, nothing, dprob,
39+
VA.alg, VA.interp, VA.dense, 0, VA.stats, VA.alg_choice, VA.retcode)
40+
(NoTangent(), Δ′, NoTangent(), NoTangent())
6441
end
6542
VA[sym, j], ODESolution_getindex_pullback
6643
end

ext/SciMLBaseZygoteExt.jl

Lines changed: 4 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ module SciMLBaseZygoteExt
33
using Zygote
44
using Zygote: @adjoint, pullback
55
import Zygote: literal_getproperty
6+
import ChainRulesCore
67
using SciMLBase
78
using SciMLBase: ODESolution, remake,
89
getobserved, build_solution, EnsembleSolution,
@@ -40,31 +41,9 @@ import SciMLStructures
4041
VA[i, j], ODESolution_getindex_pullback
4142
end
4243

43-
@adjoint function Base.getindex(VA::ODESolution, sym, j::Int)
44-
function ODESolution_getindex_pullback(Δ)
45-
i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym
46-
du, dprob = if i === nothing
47-
getter = getobserved(VA)
48-
grz = pullback(getter, sym, VA.u[j], VA.prob.p, VA.t[j])[2](Δ)
49-
du = [k == j ? grz[2] : zero(VA.u[1]) for k in 1:length(VA.u)]
50-
dp = grz[3] # pullback for p
51-
dprob = remake(VA.prob, p = dp)
52-
du, dprob
53-
else
54-
du = [m == j ? [i == k ? Δ : zero(VA.u[1][1]) for k in 1:length(VA.u[1])] :
55-
zero(VA.u[1]) for m in 1:length(VA.u)]
56-
dp = zero(VA.prob.p)
57-
dprob = remake(VA.prob, p = dp)
58-
du, dprob
59-
end
60-
T = eltype(eltype(VA.u))
61-
N = ndims(VA)
62-
Δ′ = ODESolution{T, N}(du, nothing, nothing,
63-
VA.t, VA.k, VA.discretes, dprob, VA.alg, VA.interp, VA.dense, 0, VA.stats,
64-
VA.alg_choice, VA.retcode)
65-
(Δ′, nothing, nothing)
66-
end
67-
VA[sym, j], ODESolution_getindex_pullback
44+
@adjoint function Base.getindex(VA::ODESolution, sym, j::Integer)
45+
res, pullback = ChainRulesCore.rrule(Zygote.ZygoteRuleConfig(), getindex, VA, sym, j)
46+
return res, Base.tail ∘ pullback
6847
end
6948

7049
@adjoint function EnsembleSolution(sim, time, converged, stats)

src/problems/sde_problems.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,10 +129,14 @@ function ConstructionBase.constructorof(::Type{P}) where {P <: SDEProblem}
129129
function ctor(f, g, u0, tspan, p, noise, kw, noise_rate_prototype, seed)
130130
if f isa AbstractSDEFunction
131131
iip = isinplace(f)
132+
if g !== f.g
133+
f = remake(f; g)
134+
end
135+
return SDEProblem{iip}(f, u0, tspan, p; kw..., noise, noise_rate_prototype, seed)
132136
else
133137
iip = isinplace(f, 4)
138+
return SDEProblem{iip}(f, g, u0, tspan, p; kw..., noise, noise_rate_prototype, seed)
134139
end
135-
return SDEProblem{iip}(f, g, u0, tspan, p; kw..., noise, noise_rate_prototype, seed)
136140
end
137141
end
138142

src/remake.jl

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -958,7 +958,7 @@ end
958958
function Base.showerror(io::IO, err::CyclicDependencyError)
959959
println(io, "Detected cyclic dependency in initial values:")
960960
for (k, v) in err.varmap
961-
println(io, k, " => ", "v")
961+
println(io, k, " => ", v)
962962
end
963963
println(io, "While trying to solve for variables: ", err.vars)
964964
end
@@ -1085,10 +1085,6 @@ calling `SymbolicIndexingInterface.symbolic_container`, provided for dispatch. R
10851085
the updated `newu0` and `newp`.
10861086
"""
10871087
function late_binding_update_u0_p(prob, root_indp, u0, p, t0, newu0, newp)
1088-
if hasmethod(symbolic_container, Tuple{typeof(root_indp)}) &&
1089-
(sc = symbolic_container(root_indp)) !== root_indp
1090-
return late_binding_update_u0_p(prob, sc, u0, p, t0, newu0, newp)
1091-
end
10921088
return newu0, newp
10931089
end
10941090

@@ -1099,7 +1095,7 @@ Calls `late_binding_update_u0_p(prob, root_indp, u0, p, t0, newu0, newp)` after
10991095
`root_indp`.
11001096
"""
11011097
function late_binding_update_u0_p(prob, u0, p, t0, newu0, newp)
1102-
root_indp = prob
1098+
root_indp = get_root_indp(prob)
11031099
return late_binding_update_u0_p(prob, root_indp, u0, p, t0, newu0, newp)
11041100
end
11051101

src/scimlfunctions.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4818,8 +4818,10 @@ for S in [:ODEFunction
48184818
end
48194819
end
48204820

4821+
const EMPTY_SYMBOLCACHE = SymbolCache()
4822+
48214823
function SymbolicIndexingInterface.symbolic_container(fn::AbstractSciMLFunction)
4822-
has_sys(fn) ? fn.sys : SymbolCache()
4824+
has_sys(fn) ? fn.sys : EMPTY_SYMBOLCACHE
48234825
end
48244826

48254827
function SymbolicIndexingInterface.is_observed(fn::AbstractSciMLFunction, sym)

src/solutions/save_idxs.jl

Lines changed: 43 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,11 @@ function as_diffeq_array(vt::Vector{VectorTemplate}, t)
4444
return DiffEqArray(typeof(TupleOfArraysWrapper(vt))[], t, (1, 1))
4545
end
4646

47-
function is_empty_indp(indp)
48-
isempty(variable_symbols(indp)) && isempty(parameter_symbols(indp)) &&
49-
isempty(independent_variable_symbols(indp))
47+
function get_root_indp(indp)
48+
if hasmethod(symbolic_container, Tuple{typeof(indp)}) && (sc = symbolic_container(indp)) !== indp
49+
return get_root_indp(sc)
50+
end
51+
return indp
5052
end
5153

5254
# Everything from this point on is public API
@@ -105,17 +107,26 @@ struct SavedSubsystem{V, T, M, I, P, Q, C}
105107
partition_count::C
106108
end
107109

108-
function SavedSubsystem(indp, pobj, saved_idxs)
109-
# nothing saved
110-
if saved_idxs === nothing || isempty(saved_idxs)
110+
SavedSubsystem(indp, pobj, ::Nothing) = nothing
111+
112+
function SavedSubsystem(indp, pobj, idx::Int)
113+
_indp = get_root_indp(indp)
114+
if _indp === EMPTY_SYMBOLCACHE || _indp === nothing
111115
return nothing
112116
end
117+
state_map = Dict(1 => idx)
118+
return SavedSubsystem(state_map, nothing, nothing, nothing, nothing, nothing, nothing)
119+
end
113120

114-
# this is required because problems with no system have an empty `SymbolCache`
115-
# as their symbolic container.
116-
if is_empty_indp(indp)
121+
function SavedSubsystem(indp, pobj, saved_idxs::Union{AbstractArray, Tuple})
122+
_indp = get_root_indp(indp)
123+
if _indp === EMPTY_SYMBOLCACHE || _indp === nothing
117124
return nothing
118125
end
126+
if eltype(saved_idxs) == Int
127+
state_map = Dict{Int, Int}(v => k for (k, v) in enumerate(saved_idxs))
128+
return SavedSubsystem(state_map, nothing, nothing, nothing, nothing, nothing, nothing)
129+
end
119130

120131
# array state symbolics must be scalarized
121132
saved_idxs = collect(Iterators.flatten(map(saved_idxs) do sym
@@ -357,29 +368,32 @@ corresponding to the state variables and a `SavedSubsystem` to pass to `build_so
357368
The second return value (corresponding to the `SavedSubsystem`) may be `nothing` in case
358369
one is not required. `save_idxs` may be a scalar or `nothing`.
359370
"""
371+
get_save_idxs_and_saved_subsystem(prob, ::Nothing) = nothing, nothing
372+
function get_save_idxs_and_saved_subsystem(prob, save_idxs::Vector{Int})
373+
save_idxs, SavedSubsystem(prob, parameter_values(prob), save_idxs)
374+
end
375+
function get_save_idxs_and_saved_subsystem(prob, save_idx::Int)
376+
save_idx, SavedSubsystem(prob, parameter_values(prob), save_idx)
377+
end
360378
function get_save_idxs_and_saved_subsystem(prob, save_idxs)
361-
if save_idxs === nothing
362-
saved_subsystem = nothing
379+
if !(save_idxs isa AbstractArray) || symbolic_type(save_idxs) != NotSymbolic()
380+
_save_idxs = (save_idxs,)
363381
else
364-
if !(save_idxs isa AbstractArray) || symbolic_type(save_idxs) != NotSymbolic()
365-
_save_idxs = [save_idxs]
382+
_save_idxs = save_idxs
383+
end
384+
saved_subsystem = SavedSubsystem(prob, parameter_values(prob), _save_idxs)
385+
if saved_subsystem !== nothing
386+
_save_idxs = get_saved_state_idxs(saved_subsystem)
387+
if isempty(_save_idxs)
388+
# no states to save
389+
save_idxs = Int[]
390+
elseif !(save_idxs isa AbstractArray) ||
391+
symbolic_type(save_idxs) != NotSymbolic()
392+
# only a single state to save, and save it as a scalar timeseries instead of
393+
# single-element array
394+
save_idxs = only(_save_idxs)
366395
else
367-
_save_idxs = save_idxs
368-
end
369-
saved_subsystem = SavedSubsystem(prob, parameter_values(prob), _save_idxs)
370-
if saved_subsystem !== nothing
371-
_save_idxs = get_saved_state_idxs(saved_subsystem)
372-
if isempty(_save_idxs)
373-
# no states to save
374-
save_idxs = Int[]
375-
elseif !(save_idxs isa AbstractArray) ||
376-
symbolic_type(save_idxs) != NotSymbolic()
377-
# only a single state to save, and save it as a scalar timeseries instead of
378-
# single-element array
379-
save_idxs = only(_save_idxs)
380-
else
381-
save_idxs = _save_idxs
382-
end
396+
save_idxs = _save_idxs
383397
end
384398
end
385399

test/downstream/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ DelayDiffEq = "5"
3434
DiffEqCallbacks = "3, 4"
3535
ForwardDiff = "0.10"
3636
JumpProcesses = "9.10"
37-
ModelingToolkit = "9.64.1"
37+
ModelingToolkit = "9.64.3"
3838
ModelingToolkitStandardLibrary = "2.7"
3939
NonlinearSolve = "2, 3, 4"
4040
Optimization = "4"

test/downstream/adjoints.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,7 @@ u0 = [lorenz1.x => 1.0,
2222
lorenz1.z => 0.0,
2323
lorenz2.x => 0.0,
2424
lorenz2.y => 1.0,
25-
lorenz2.z => 0.0,
26-
a => 2.0]
25+
lorenz2.z => 0.0]
2726

2827
p = [lorenz1.σ => 10.0,
2928
lorenz1.ρ => 28.0,
@@ -68,7 +67,7 @@ gs_ts, = Zygote.gradient(sol) do sol
6867
sum(sum.(sol[[lorenz1.x, lorenz2.x], :]))
6968
end
7069

71-
@test all(map(x -> x == true_grad_vecsym, gs_ts))
70+
@test all(map(x -> x == true_grad_vecsym, gs_ts.u))
7271

7372
# BatchedInterface AD
7473
@variables x(t)=1.0 y(t)=1.0 z(t)=1.0 w(t)=1.0

0 commit comments

Comments
 (0)