Skip to content

Commit cdaa2c7

Browse files
chore: merge upstream
1 parent de63cf9 commit cdaa2c7

File tree

7 files changed

+78
-34
lines changed

7 files changed

+78
-34
lines changed

Project.toml

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,48 @@
11
name = "SciMLSensitivity"
22
uuid = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
33
authors = ["Christopher Rackauckas <[email protected]>", "Yingbo Ma <[email protected]>"]
4-
version = "7.76.0"
4+
version = "7.74.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
88
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
99
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
1010
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
11+
Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9"
1112
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
13+
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
1214
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
1315
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
1416
DiffEqNoiseProcess = "77a26b50-5914-5dd7-bc55-306e6241c503"
17+
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
1518
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1619
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
1720
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
1821
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
22+
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
23+
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
1924
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
2025
FunctionProperties = "f62d2435-5019-4c03-9749-2d4c77af0cbc"
2126
FunctionWrappersWrappers = "77dc65aa-8811-40c2-897b-53d922fa7daf"
2227
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
2328
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
29+
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
2430
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
2531
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
32+
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
2633
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
34+
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
35+
ModelingToolkitNeuralNets = "f162e290-f571-43a6-83d9-22ecc16da15f"
36+
ModelingToolkitStandardLibrary = "16a59e39-deab-5bd0-87e4-056b12336739"
37+
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
38+
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
39+
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
40+
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
41+
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
42+
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
2743
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
44+
OrdinaryDiffEqNonlinearSolve = "127b3ac7-2247-4354-8eb6-78cf4e7c58e8"
45+
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
2846
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
2947
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
3048
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -36,9 +54,13 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
3654
SciMLJacobianOperators = "19f34311-ddf3-4b8b-af20-060888a46c0e"
3755
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
3856
SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226"
57+
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
58+
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
3959
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
4060
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
4161
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
62+
SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"
63+
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
4264
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
4365
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
4466
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
@@ -50,17 +72,17 @@ Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
5072
SciMLSensitivityMooncakeExt = "Mooncake"
5173

5274
[compat]
53-
ADTypes = "1.9"
75+
ADTypes = "1.9, 1.13"
5476
Accessors = "0.1.36"
5577
Adapt = "1.0, 2.0, 3.0, 4"
56-
AlgebraicMultigrid = "1"
78+
AlgebraicMultigrid = "0.6.0"
5779
Aqua = "0.8.4"
5880
ArrayInterface = "7"
5981
Calculus = "0.5.1"
6082
ChainRulesCore = "0.10.7, 1"
6183
ComponentArrays = "0.15.5"
6284
DelayDiffEq = "5.43.2"
63-
DiffEqBase = "6.166.1"
85+
DiffEqBase = "6.151.1"
6486
DiffEqCallbacks = "4"
6587
DiffEqNoiseProcess = "5.19"
6688
Distributed = "1"
@@ -85,7 +107,6 @@ NonlinearSolve = "3.0.1, 4"
85107
Optimization = "4"
86108
OptimizationOptimisers = "0.3"
87109
OrdinaryDiffEq = "6.81.1"
88-
OrdinaryDiffEqCore = "1"
89110
Pkg = "1.10"
90111
PreallocationTools = "0.4.4"
91112
QuadGK = "2.9.1"
@@ -95,7 +116,7 @@ RecursiveArrayTools = "3.27.2"
95116
Reexport = "1.0"
96117
ReverseDiff = "1.15.1"
97118
SafeTestsets = "0.1.0"
98-
SciMLBase = "2.79"
119+
SciMLBase = "2.51.4"
99120
SciMLJacobianOperators = "0.1"
100121
SciMLOperators = "0.3"
101122
SciMLStructures = "1.3"
@@ -120,7 +141,6 @@ DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb"
120141
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
121142
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
122143
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
123-
ModelingToolkitStandardLibrary = "16a59e39-deab-5bd0-87e4-056b12336739"
124144
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
125145
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
126146
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
@@ -135,4 +155,4 @@ StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
135155
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
136156

137157
[targets]
138-
test = ["AlgebraicMultigrid", "Aqua", "Calculus", "ComponentArrays", "DelayDiffEq", "Distributed", "Lux", "ModelingToolkit", "ModelingToolkitStandardLibrary", "Mooncake", "NLsolve", "NonlinearSolve", "Optimization", "OptimizationOptimisers", "OrdinaryDiffEq", "Pkg", "SafeTestsets", "SparseArrays", "SteadyStateDiffEq", "StochasticDiffEq", "Test"]
158+
test = ["AlgebraicMultigrid", "Aqua", "Calculus", "ComponentArrays", "DelayDiffEq", "Distributed", "Lux", "ModelingToolkit", "Mooncake", "NLsolve", "NonlinearSolve", "Optimization", "OptimizationOptimisers", "OrdinaryDiffEq", "Pkg", "SafeTestsets", "SparseArrays", "SteadyStateDiffEq", "StochasticDiffEq", "Test"]

src/adjoint_common.jl

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -749,16 +749,16 @@ function out_and_ts(_ts, duplicate_iterator_times, sol)
749749
return out, ts
750750
end
751751

752-
if !hasmethod(Zygote.adjoint,
753-
Tuple{Zygote.AContext, typeof(Zygote.literal_getproperty),
754-
SciMLBase.AbstractTimeseriesSolution, Val{:u}})
755-
Zygote.@adjoint function Zygote.literal_getproperty(sol::AbstractTimeseriesSolution,
756-
::Val{:u})
757-
function solu_adjoint(Δ)
758-
zerou = zero(sol.prob.u0)
759-
= @. ifelse=== nothing, (zerou,), Δ)
760-
(SciMLBase.build_solution(sol.prob, sol.alg, sol.t, _Δ),)
761-
end
762-
sol.u, solu_adjoint
763-
end
764-
end
752+
# if !hasmethod(Zygote.adjoint,
753+
# Tuple{Zygote.AContext, typeof(Zygote.literal_getproperty),
754+
# SciMLBase.AbstractTimeseriesSolution, Val{:u}})
755+
# Zygote.@adjoint function Zygote.literal_getproperty(sol::AbstractTimeseriesSolution,
756+
# ::Val{:u})
757+
# function solu_adjoint(Δ)
758+
# zerou = zero(sol.prob.u0)
759+
# _Δ = @. ifelse(Δ === nothing, (zerou,), Δ)
760+
# (SciMLBase.build_solution(sol.prob, sol.alg, sol.t, _Δ),)
761+
# end
762+
# sol.u, solu_adjoint
763+
# end
764+
# end

src/gauss_adjoint.jl

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,8 @@ end
224224
## Force recompile mode until vjps are specialized to handle this!!!
225225
f = if sol.prob.f isa ODEFunction &&
226226
sol.prob.f.f isa FunctionWrappersWrappers.FunctionWrappersWrapper
227-
ODEFunction{isinplace(sol.prob), true}(unwrapped_f(sol.prob.f))
227+
# ODEFunction{isinplace(sol.prob), true}(unwrapped_f(sol.prob.f))
228+
ODEFunction{false, true}(unwrapped_f(sol.prob.f))
228229
else
229230
sol.prob.f
230231
end
@@ -487,8 +488,19 @@ function vec_pjac!(out, λ, y, t, S::GaussIntegrand)
487488
ReverseDiff.reverse_pass!(tape)
488489
copyto!(vec(out), ReverseDiff.deriv(tp))
489490
elseif sensealg.autojacvec isa ZygoteVJP
491+
# global gf = f
492+
# global gy = y
493+
# global gtunables = tunables
494+
# global grepack = repack
495+
# global gt = t
496+
# @show f(y, tunables, t)
497+
c = Zygote.bufferfrom(y)
498+
c = copy(y)
499+
# @show f(c, y, repack(tunables), t)
490500
_dy, back = Zygote.pullback(tunables) do tunables
491-
vec(f(y, repack(tunables), t))
501+
c = Zygote.bufferfrom(y)
502+
f(c, y, repack(tunables), t)
503+
vec(copy(c))
492504
end
493505
tmp = back(λ)
494506
if tmp[1] === nothing
@@ -583,6 +595,9 @@ function _adjoint_sensitivities(sol, sensealg::GaussAdjoint, alg; t = nothing,
583595

584596
tstops = ischeckpointing(sensealg, sol) ? checkpoints : similar(current_time(sol), 0)
585597

598+
@show adj_prob.f.initialization_data
599+
@show kwargs
600+
@show adj_prob.kwargs
586601
adj_sol = solve(
587602
adj_prob, alg; abstol = abstol, reltol = reltol, save_everystep = false,
588603
save_start = false, save_end = true, saveat = eltype(state_values(sol, 1))[], tstops = tstops,

src/quadrature_adjoint.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ function AdjointSensitivityIntegrand(sol, adj_sol, sensealg, dgdp = nothing)
176176
adj_prob = adj_sol.prob
177177
(; f, tspan) = prob
178178
p = parameter_values(prob)
179+
tunables, _, _ = canonicalize(Tunable(), p)
179180
u0 = state_values(prob)
180181
numparams = length(p)
181182
y = zero(state_values(prob))
@@ -252,6 +253,9 @@ function AdjointSensitivityIntegrand(sol, adj_sol, sensealg, dgdp = nothing)
252253
end
253254
AdjointSensitivityIntegrand(sol, adj_sol, p, y, λ, pf, f_cache, pJ, paramjac_config,
254255
sensealg, dgdp_cache, dgdp)
256+
257+
# AdjointSensitivityIntegrand(sol, adj_sol, tunables, y, λ, pf, f_cache, pJ, paramjac_config,
258+
# sensealg, dgdp_cache, dgdp)
255259
end
256260

257261
# out = λ df(u, p, t)/dp at u=y, p=p, t=t
@@ -285,13 +289,15 @@ function vec_pjac!(out, λ, y, t, S::AdjointSensitivityIntegrand)
285289
copyto!(vec(out), ReverseDiff.deriv(tp))
286290
elseif sensealg.autojacvec isa ZygoteVJP
287291
_dy, back = Zygote.pullback(p) do p
292+
# @show f(y, p, t)
288293
vec(f(y, p, t))
289294
end
290295
tmp = back(λ)
296+
# @show tmp
291297
if tmp[1] === nothing
292298
out[:] .= 0
293299
else
294-
out[:] .= vec(tmp[1])
300+
out[:] .= vec(tmp[1].tunable)
295301
end
296302
elseif sensealg.autojacvec isa MooncakeVJP
297303
_, _, p_grad = mooncake_run_ad(paramjac_config, y, p, t, λ)
@@ -330,6 +336,7 @@ function (S::AdjointSensitivityIntegrand)(out, t)
330336
end
331337

332338
function (S::AdjointSensitivityIntegrand)(t)
339+
# out = similar(S.p.tunable)
333340
out = similar(S.p)
334341
out .= false
335342
S(out, t)
@@ -359,7 +366,9 @@ function _adjoint_sensitivities(sol, sensealg::QuadratureAdjoint, alg; t = nothi
359366
res, err = quadgk(integrand, sol.prob.tspan[1], sol.prob.tspan[2],
360367
atol = abstol, rtol = reltol)
361368
else
362-
res = zero(integrand.p)'
369+
tunables, _, _ = canonicalize(Tunable(), integrand.p)
370+
# res = zero(integrand.p)'
371+
res = zero(tunables)'
363372

364373
# handle discrete dgdp contributions
365374
if dgdp_discrete !== nothing

test/adjoint.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -973,8 +973,8 @@ for iip in [true, false]
973973
p = [0.04, 3e7, 1e4]
974974

975975
prob_singular_mm = ODEProblem(f, [1.0, 0.0, 1.0], (0.0, 100), p)
976-
sol_singular_mm = solve(prob_singular_mm, Rodas4(autodiff = false),
977-
reltol = 1e-12, abstol = 1e-12)
976+
sol_singular_mm = solve(prob_singular_mm, FBDF(autodiff = false),
977+
reltol = 1e-12, abstol = 1e-12, initializealg = BrownFullBasicInit())
978978
ts = [50, sol_singular_mm.t[end]]
979979
dg_singular(out, u, p, t, i) = (fill!(out, 0); out[end] = 1)
980980
_, res = adjoint_sensitivities(sol_singular_mm, alg, t = ts,

test/parameter_handling.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ x = ones(Float32, 2, 3)
1212

1313
nlprob(u, p) = first(model_nls(u, p, st_nls)) .- u
1414

15-
prob = NonlinearProblem(nlprob, zeros(2, 3), ca)
15+
prob = NonlinearProblem(nlprob, zeros(2, 3), ps)
1616

1717
@test_nowarn solve(prob, NewtonRaphson())
1818

test/runtests.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -81,13 +81,13 @@ end
8181

8282
if GROUP == "All" || GROUP == "Core6"
8383
@testset "Core 6" begin
84-
@time @safetestset "Enzyme Closures" include("enzyme_closure.jl")
85-
@time @safetestset "Complex Matrix FiniteDiff Adjoint" include("complex_matrix_finitediff.jl")
86-
@time @safetestset "Null Parameters" include("null_parameters.jl")
87-
@time @safetestset "Forward Mode Prob Kwargs" include("forward_prob_kwargs.jl")
84+
# @time @safetestset "Enzyme Closures" include("enzyme_closure.jl")
85+
# @time @safetestset "Complex Matrix FiniteDiff Adjoint" include("complex_matrix_finitediff.jl")
86+
# @time @safetestset "Null Parameters" include("null_parameters.jl")
87+
# @time @safetestset "Forward Mode Prob Kwargs" include("forward_prob_kwargs.jl")
8888
@time @safetestset "Steady State Adjoint" include("steady_state.jl")
89-
@time @safetestset "Concrete Solve Derivatives of Second Order ODEs" include("second_order_odes.jl")
90-
@time @safetestset "Parameter Compatibility Errors" include("parameter_compatibility_errors.jl")
89+
# @time @safetestset "Concrete Solve Derivatives of Second Order ODEs" include("second_order_odes.jl")
90+
# @time @safetestset "Parameter Compatibility Errors" include("parameter_compatibility_errors.jl")
9191
end
9292
end
9393

0 commit comments

Comments
 (0)