Skip to content

Commit 054dc29

Browse files
Improve SciMLStructures support
Fixes #1233
1 parent 4579890 commit 054dc29

File tree

2 files changed

+89
-5
lines changed

2 files changed

+89
-5
lines changed

src/gauss_adjoint.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -421,16 +421,16 @@ function GaussIntegrand(sol, sensealg, checkpoints, dgdp = nothing)
421421
elseif sensealg.autojacvec isa MooncakeVJP
422422
pf = get_pf(sensealg.autojacvec, prob, f)
423423
paramjac_config = get_paramjac_config(
424-
MooncakeLoaded(), sensealg.autojacvec, pf, p, f, y, tspan[2])
424+
MooncakeLoaded(), sensealg.autojacvec, pf, tunables, f, y, tspan[2])
425425
pJ = nothing
426426
elseif isautojacvec # Zygote
427427
paramjac_config = nothing
428428
pf = nothing
429429
pJ = nothing
430430
else
431-
pf = SciMLBase.ParamJacobianWrapper(unwrappedf, tspan[1], y)
431+
pf = SciMLBase.ParamJacobianWrapper((du,u,p,t)->unwrappedf(du,u,repack(p),t), tspan[1], y)
432432
pJ = similar(u0, length(u0), numparams)
433-
paramjac_config = build_param_jac_config(sensealg, pf, y, p)
433+
paramjac_config = build_param_jac_config(sensealg, pf, y, tunables)
434434
end
435435

436436
cpsol = sol
@@ -460,7 +460,7 @@ function vec_pjac!(out, λ, y, t, S::GaussIntegrand)
460460
else
461461
pf.t = t
462462
pf.u = y
463-
jacobian!(pJ, pf, p, f_cache, sensealg, paramjac_config)
463+
jacobian!(pJ, pf, tunables, f_cache, sensealg, paramjac_config)
464464
end
465465
mul!(out', λ', pJ)
466466
elseif sensealg.autojacvec isa ReverseDiffVJP
@@ -610,7 +610,7 @@ function _adjoint_sensitivities(sol, sensealg::GaussAdjoint, alg; t = nothing,
610610
res .+= out
611611
iλ .= zero(eltype(iλ))
612612
end
613-
end
613+
end
614614

615615
return state_values(adj_sol)[end], __maybe_adjoint(res)
616616
end

test/scimlstructures_interface.jl

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,87 @@ Zygote.gradient(0.1ones(length(SS.canonicalize(SS.Tunable(), p)[1]))) do tunable
7878
sol = solve(newprob, Tsit5())
7979
return sum(sol.u[end])
8080
end
81+
82+
using OrdinaryDiffEq
83+
using StableRNGs, Lux
84+
using ComponentArrays
85+
using SciMLSensitivity
86+
import SciMLStructures as SS
87+
using Zygote
88+
using Parameters
89+
using ADTypes
90+
using Test
91+
92+
mutable struct myparam{M,P,S}
93+
model::M
94+
ps ::P
95+
st ::S
96+
α :: Float64
97+
β :: Float64
98+
γ :: Float64
99+
end
100+
101+
SS.isscimlstructure(::myparam) = true
102+
SS.ismutablescimlstructure(::myparam) = true
103+
SS.hasportion(::SS.Tunable, ::myparam) = true
104+
function SS.canonicalize(::SS.Tunable, p::myparam)
105+
buffer = copy(p.ps)
106+
repack = let p = p
107+
function repack(newbuffer)
108+
SS.replace(SS.Tunable(), p, newbuffer)
109+
end
110+
end
111+
return buffer, repack, false
112+
end
113+
function SS.replace(::SS.Tunable, p::myparam, newbuffer)
114+
return myparam(p.model, newbuffer, p.st, p.α, p.β, p.γ)
115+
end
116+
function SS.replace!(::SS.Tunable, p::myparam, newbuffer)
117+
p.ps = newbuffer
118+
return p
119+
end
120+
function initialize()
121+
# Defining the neural network
122+
U = Lux.Chain(Lux.Dense(3,30,tanh),Lux.Dense(30,30,tanh),Lux.Dense(30,1))
123+
rng = StableRNG(1111)
124+
_para,st = Lux.setup(rng,U)
125+
_para = ComponentArray(_para)
126+
# Setting the parameters
127+
α = 0.5
128+
β = 0.1
129+
γ = 0.01
130+
return myparam(U,_para,st,α,β,γ)
131+
end
132+
function UDE_model!(du, u, p, t)
133+
# Extracting parameters
134+
Parameters.@unpack model, ps, st, α, β, γ = p
135+
o = model(u,ps, st)[1][1]
136+
du[1] = o * α * u[1] + β * u[2] + γ * u[3]
137+
du[2] = -α * u[1] + β * u[2] - γ * u[3]
138+
du[3] = α * u[1] - β * u[2] + γ * u[3]
139+
nothing
140+
end
141+
142+
p = initialize()
143+
function run_diff(ps)
144+
u01 = [1.0, 0.0, 0.0]
145+
tspan = (0.0, 10.0)
146+
prob = ODEProblem(UDE_model!, u01, tspan, ps)
147+
sol = solve(prob, Rosenbrock23(), saveat = 0.1)
148+
return sol.u |> last |> sum
149+
end
150+
151+
run_diff(initialize())
152+
@test !iszero(Zygote.gradient(run_diff, initialize())[1].ps)
153+
154+
function run_diff(ps,sensealg)
155+
u01 = [1.0, 0.0, 0.0]
156+
tspan = (0.0, 10.0)
157+
prob = ODEProblem(UDE_model!, u01, tspan, ps)
158+
sol = solve(prob, Rosenbrock23(), saveat = 0.1, sensealg=sensealg)
159+
return sol.u |> last |> sum
160+
end
161+
162+
run_diff(initialize())
163+
@test !iszero(Zygote.gradient(run_diff, initialize(), GaussAdjoint())[1].ps)
164+
@test !iszero(Zygote.gradient(run_diff, initialize(), GaussAdjoint(autojacvec=false))[1].ps)

0 commit comments

Comments
 (0)