@@ -9,30 +9,30 @@ using OrdinaryDiffEqTsit5
99using LinearAlgebra
1010
1111mutable struct SubproblemParameters{P, Q, R}
12- p::P # tunable
13- q::Q
14- r::R
12+ p::P # tunable
13+ q::Q
14+ r::R
1515end
1616
1717mutable struct Parameters{P, C}
18- subparams::P
19- coeffs::C # tunable matrix
18+ subparams::P
19+ coeffs::C # tunable matrix
2020end
2121
2222# the rhs is `du[i] = p[i] * u[i]^2 + q[i] * u[i] + r[i] * t` for i in 1:length(subparams)
2323# and `du[length(subparams)+1:end] .= coeffs * u`
2424function rhs!(du, u, p::Parameters, t)
25- for (i, subpars) in enumerate(p.subparams)
26- du[i] = subpars.p * u[i]^2 + subpars.q * u[i] + subpars.r * t
27- end
28- N = length(p.subparams)
29- mul!(view(du, (N + 1):(length(du))), p.coeffs, u)
30- return nothing
25+ for (i, subpars) in enumerate(p.subparams)
26+ du[i] = subpars.p * u[i]^2 + subpars.q * u[i] + subpars.r * t
27+ end
28+ N = length(p.subparams)
29+ mul!(view(du, (N+ 1):(length(du))), p.coeffs, u)
30+ return nothing
3131end
3232
3333u = sin.(0.1:0.1:1.0)
3434subparams = [SubproblemParameters(0.1i, 0.2i, 0.3i) for i in 1:5]
35- p = Parameters(subparams, cos.([0.1i + 0.33j for i in 1:5, j in 1:10]))
35+ p = Parameters(subparams, cos.([0.1i+ 0.33j for i in 1:5, j in 1:10]))
3636tspan = (0.0, 1.0)
3737
3838prob = ODEProblem(rhs!, u, tspan, p)
@@ -47,13 +47,12 @@ using SciMLSensitivity
4747
4848# 5 subparams[i].p, 50 elements in coeffs
4949function simulate_with_tunables(tunables)
50- subpars = [SubproblemParameters(tunables[i], subpar.q, subpar.r)
51- for (i, subpar) in enumerate(p.subparams)]
52- coeffs = reshape(tunables[6:end], size(p.coeffs))
53- newp = Parameters(subpars, coeffs)
54- newprob = remake(prob; p = newp)
55- sol = solve(newprob, Tsit5())
56- return sum(sol.u[end])
50+ subpars = [SubproblemParameters(tunables[i], subpar.q, subpar.r) for (i, subpar) in enumerate(p.subparams)]
51+ coeffs = reshape(tunables[6:end], size(p.coeffs))
52+ newp = Parameters(subpars, coeffs)
53+ newprob = remake(prob; p = newp)
54+ sol = solve(newprob, Tsit5())
55+ return sum(sol.u[end])
5756end
5857```
5958
@@ -75,49 +74,46 @@ SS.ismutablescimlstructure(::Parameters) = true
7574SS.hasportion(::SS.Tunable, ::Parameters) = true
7675
7776function SS.canonicalize(::SS.Tunable, p::Parameters)
78- # concatenate all tunable values into a single vector
79- buffer = vcat([subpar.p for subpar in p.subparams], vec(p.coeffs))
80-
81- # repack takes a new vector of the same length as `buffer`, and constructs
82- # a new `Parameters` object using the values from the new vector for tunables
83- # and retaining old values for other parameters. This is exactly what replace does,
84- # so we can use that instead.
85- repack = let p = p
86- function repack(newbuffer)
87- SS.replace(SS.Tunable(), p, newbuffer)
88- end
77+ # concatenate all tunable values into a single vector
78+ buffer = vcat([subpar.p for subpar in p.subparams], vec(p.coeffs))
79+
80+ # repack takes a new vector of the same length as `buffer`, and constructs
81+ # a new `Parameters` object using the values from the new vector for tunables
82+ # and retaining old values for other parameters. This is exactly what replace does,
83+ # so we can use that instead.
84+ repack = let p = p
85+ function repack(newbuffer)
86+ SS.replace(SS.Tunable(), p, newbuffer)
8987 end
90- # the canonicalized vector, the repack function, and a boolean indicating
91- # whether the buffer aliases values in the parameter object (here, it doesn't)
92- return buffer, repack, false
88+ end
89+ # the canonicalized vector, the repack function, and a boolean indicating
90+ # whether the buffer aliases values in the parameter object (here, it doesn't)
91+ return buffer, repack, false
9392end
9493
9594function SS.replace(::SS.Tunable, p::Parameters, newbuffer)
96- N = length(p.subparams) + length(p.coeffs)
97- @assert length(newbuffer) == N
98- subparams = [SubproblemParameters(newbuffer[i], subpar.q, subpar.r)
99- for (i, subpar) in enumerate(p.subparams)]
100- coeffs = reshape(
101- view(newbuffer, (length(p.subparams) + 1):length(newbuffer)), size(p.coeffs))
102- return Parameters(subparams, coeffs)
95+ N = length(p.subparams) + length(p.coeffs)
96+ @assert length(newbuffer) == N
97+ subparams = [SubproblemParameters(newbuffer[i], subpar.q, subpar.r) for (i, subpar) in enumerate(p.subparams)]
98+ coeffs = reshape(view(newbuffer, (length(p.subparams)+1):length(newbuffer)), size(p.coeffs))
99+ return Parameters(subparams, coeffs)
103100end
104101
105102function SS.replace!(::SS.Tunable, p::Parameters, newbuffer)
106- N = length(p.subparams) + length(p.coeffs)
107- @assert length(newbuffer) == N
108- for (subpar, val) in zip(p.subparams, newbuffer)
109- subpar.p = val
110- end
111- copyto!(coeffs, view(newbuffer, (length(p.subparams) + 1):length(newbuffer)))
112- return p
103+ N = length(p.subparams) + length(p.coeffs)
104+ @assert length(newbuffer) == N
105+ for (subpar, val) in zip(p.subparams, newbuffer)
106+ subpar.p = val
107+ end
108+ copyto!(coeffs, view(newbuffer, (length(p.subparams)+ 1):length(newbuffer)))
109+ return p
113110end
114111```
115112
116113Now, we should be able to differentiate through the ODE solve.
117114
118115``` @example basic_tutorial
119- Zygote.gradient(
120- simulate_with_tunables, 0.1ones(length(SS.canonicalize(SS.Tunable(), p)[1])))
116+ Zygote.gradient(simulate_with_tunables, 0.1ones(length(SS.canonicalize(SS.Tunable(), p)[1])))
121117```
122118
123119We can also implement a ` Constants ` portion to store the rest of the values:
@@ -126,30 +122,29 @@ We can also implement a `Constants` portion to store the rest of the values:
126122SS.hasportion(::SS.Constants, ::Parameters) = true
127123
128124function SS.canonicalize(::SS.Constants, p::Parameters)
129- buffer = mapreduce(vcat, p.subparams) do subpar
130- [subpar.q, subpar.r]
131- end
132- repack = let p = p
133- function repack(newbuffer)
134- SS.replace(SS.Constants(), p, newbuffer)
135- end
125+ buffer = mapreduce(vcat, p.subparams) do subpar
126+ [subpar.q, subpar.r]
127+ end
128+ repack = let p = p
129+ function repack(newbuffer)
130+ SS.replace(SS.Constants(), p, newbuffer)
136131 end
132+ end
137133
138- return buffer, repack, false
134+ return buffer, repack, false
139135end
140136
141137function SS.replace(::SS.Constants, p::Parameters, newbuffer)
142- subpars = [SubproblemParameters(p.subparams[i].p, newbuffer[2i - 1], newbuffer[2i])
143- for i in eachindex(p.subparams)]
144- return Parameters(subpars, p.coeffs)
138+ subpars = [SubproblemParameters(p.subparams[i].p, newbuffer[2i-1], newbuffer[2i]) for i in eachindex(p.subparams)]
139+ return Parameters(subpars, p.coeffs)
145140end
146141
147142function SS.replace!(::SS.Constants, p::Parameters, newbuffer)
148- for i in eachindex(p.subparams)
149- p.subparams[i].q = newbuffer[2i - 1]
150- p.subparams[i].r = newbuffer[2i]
151- end
152- return p
143+ for i in eachindex(p.subparams)
144+ p.subparams[i].q = newbuffer[2i- 1]
145+ p.subparams[i].r = newbuffer[2i]
146+ end
147+ return p
153148end
154149
155150buf, repack, alias = SS.canonicalize(SS.Constants(), p)
0 commit comments