@@ -9,30 +9,30 @@ using OrdinaryDiffEqTsit5
99using LinearAlgebra
1010
1111mutable struct SubproblemParameters{P, Q, R}
12- p::P # tunable
13- q::Q
14- r::R
12+ p::P # tunable
13+ q::Q
14+ r::R
1515end
1616
1717mutable struct Parameters{P, C}
18- subparams::P
19- coeffs::C # tunable matrix
18+ subparams::P
19+ coeffs::C # tunable matrix
2020end
2121
2222# the rhs is `du[i] = p[i] * u[i]^2 + q[i] * u[i] + r[i] * t` for i in 1:length(subparams)
2323# and `du[length(subparams)+1:end] .= coeffs * u`
2424function rhs!(du, u, p::Parameters, t)
25- for (i, subpars) in enumerate(p.subparams)
26- du[i] = subpars.p * u[i]^2 + subpars.q * u[i] + subpars.r * t
27- end
28- N = length(p.subparams)
29- mul!(view(du, (N+ 1):(length(du))), p.coeffs, u)
30- return nothing
25+ for (i, subpars) in enumerate(p.subparams)
26+ du[i] = subpars.p * u[i]^2 + subpars.q * u[i] + subpars.r * t
27+ end
28+ N = length(p.subparams)
29+ mul!(view(du, (N + 1):(length(du))), p.coeffs, u)
30+ return nothing
3131end
3232
3333u = sin.(0.1:0.1:1.0)
3434subparams = [SubproblemParameters(0.1i, 0.2i, 0.3i) for i in 1:5]
35- p = Parameters(subparams, cos.([0.1i+ 0.33j for i in 1:5, j in 1:10]))
35+ p = Parameters(subparams, cos.([0.1i + 0.33j for i in 1:5, j in 1:10]))
3636tspan = (0.0, 1.0)
3737
3838prob = ODEProblem(rhs!, u, tspan, p)
@@ -47,12 +47,13 @@ using SciMLSensitivity
4747
4848# 5 subparams[i].p, 50 elements in coeffs
4949function simulate_with_tunables(tunables)
50- subpars = [SubproblemParameters(tunables[i], subpar.q, subpar.r) for (i, subpar) in enumerate(p.subparams)]
51- coeffs = reshape(tunables[6:end], size(p.coeffs))
52- newp = Parameters(subpars, coeffs)
53- newprob = remake(prob; p = newp)
54- sol = solve(newprob, Tsit5())
55- return sum(sol.u[end])
50+ subpars = [SubproblemParameters(tunables[i], subpar.q, subpar.r)
51+ for (i, subpar) in enumerate(p.subparams)]
52+ coeffs = reshape(tunables[6:end], size(p.coeffs))
53+ newp = Parameters(subpars, coeffs)
54+ newprob = remake(prob; p = newp)
55+ sol = solve(newprob, Tsit5())
56+ return sum(sol.u[end])
5657end
5758```
5859
@@ -74,46 +75,49 @@ SS.ismutablescimlstructure(::Parameters) = true
7475SS.hasportion(::SS.Tunable, ::Parameters) = true
7576
7677function SS.canonicalize(::SS.Tunable, p::Parameters)
77- # concatenate all tunable values into a single vector
78- buffer = vcat([subpar.p for subpar in p.subparams], vec(p.coeffs))
79-
80- # repack takes a new vector of the same length as `buffer`, and constructs
81- # a new `Parameters` object using the values from the new vector for tunables
82- # and retaining old values for other parameters. This is exactly what replace does,
83- # so we can use that instead.
84- repack = let p = p
85- function repack(newbuffer)
86- SS.replace(SS.Tunable(), p, newbuffer)
78+ # concatenate all tunable values into a single vector
79+ buffer = vcat([subpar.p for subpar in p.subparams], vec(p.coeffs))
80+
81+ # repack takes a new vector of the same length as `buffer`, and constructs
82+ # a new `Parameters` object using the values from the new vector for tunables
83+ # and retaining old values for other parameters. This is exactly what replace does,
84+ # so we can use that instead.
85+ repack = let p = p
86+ function repack(newbuffer)
87+ SS.replace(SS.Tunable(), p, newbuffer)
88+ end
8789 end
88- end
89- # the canonicalized vector, the repack function, and a boolean indicating
90- # whether the buffer aliases values in the parameter object (here, it doesn't)
91- return buffer, repack, false
90+ # the canonicalized vector, the repack function, and a boolean indicating
91+ # whether the buffer aliases values in the parameter object (here, it doesn't)
92+ return buffer, repack, false
9293end
9394
9495function SS.replace(::SS.Tunable, p::Parameters, newbuffer)
95- N = length(p.subparams) + length(p.coeffs)
96- @assert length(newbuffer) == N
97- subparams = [SubproblemParameters(newbuffer[i], subpar.q, subpar.r) for (i, subpar) in enumerate(p.subparams)]
98- coeffs = reshape(view(newbuffer, (length(p.subparams)+1):length(newbuffer)), size(p.coeffs))
99- return Parameters(subparams, coeffs)
96+ N = length(p.subparams) + length(p.coeffs)
97+ @assert length(newbuffer) == N
98+ subparams = [SubproblemParameters(newbuffer[i], subpar.q, subpar.r)
99+ for (i, subpar) in enumerate(p.subparams)]
100+ coeffs = reshape(
101+ view(newbuffer, (length(p.subparams) + 1):length(newbuffer)), size(p.coeffs))
102+ return Parameters(subparams, coeffs)
100103end
101104
102105function SS.replace!(::SS.Tunable, p::Parameters, newbuffer)
103- N = length(p.subparams) + length(p.coeffs)
104- @assert length(newbuffer) == N
105- for (subpar, val) in zip(p.subparams, newbuffer)
106- subpar.p = val
107- end
108- copyto!(coeffs, view(newbuffer, (length(p.subparams)+ 1):length(newbuffer)))
109- return p
106+ N = length(p.subparams) + length(p.coeffs)
107+ @assert length(newbuffer) == N
108+ for (subpar, val) in zip(p.subparams, newbuffer)
109+ subpar.p = val
110+ end
111+ copyto!(coeffs, view(newbuffer, (length(p.subparams) + 1):length(newbuffer)))
112+ return p
110113end
111114```
112115
113116Now, we should be able to differentiate through the ODE solve.
114117
115118``` @example basic_tutorial
116- Zygote.gradient(simulate_with_tunables, 0.1ones(length(SS.canonicalize(SS.Tunable(), p)[1])))
119+ Zygote.gradient(
120+ simulate_with_tunables, 0.1ones(length(SS.canonicalize(SS.Tunable(), p)[1])))
117121```
118122
119123We can also implement a ` Constants ` portion to store the rest of the values:
@@ -122,29 +126,30 @@ We can also implement a `Constants` portion to store the rest of the values:
122126SS.hasportion(::SS.Constants, ::Parameters) = true
123127
124128function SS.canonicalize(::SS.Constants, p::Parameters)
125- buffer = mapreduce(vcat, p.subparams) do subpar
126- [subpar.q, subpar.r]
127- end
128- repack = let p = p
129- function repack(newbuffer)
130- SS.replace(SS.Constants(), p, newbuffer)
129+ buffer = mapreduce(vcat, p.subparams) do subpar
130+ [subpar.q, subpar.r]
131+ end
132+ repack = let p = p
133+ function repack(newbuffer)
134+ SS.replace(SS.Constants(), p, newbuffer)
135+ end
131136 end
132- end
133137
134- return buffer, repack, false
138+ return buffer, repack, false
135139end
136140
137141function SS.replace(::SS.Constants, p::Parameters, newbuffer)
138- subpars = [SubproblemParameters(p.subparams[i].p, newbuffer[2i-1], newbuffer[2i]) for i in eachindex(p.subparams)]
139- return Parameters(subpars, p.coeffs)
142+ subpars = [SubproblemParameters(p.subparams[i].p, newbuffer[2i - 1], newbuffer[2i])
143+ for i in eachindex(p.subparams)]
144+ return Parameters(subpars, p.coeffs)
140145end
141146
142147function SS.replace!(::SS.Constants, p::Parameters, newbuffer)
143- for i in eachindex(p.subparams)
144- p.subparams[i].q = newbuffer[2i- 1]
145- p.subparams[i].r = newbuffer[2i]
146- end
147- return p
148+ for i in eachindex(p.subparams)
149+ p.subparams[i].q = newbuffer[2i - 1]
150+ p.subparams[i].r = newbuffer[2i]
151+ end
152+ return p
148153end
149154
150155buf, repack, alias = SS.canonicalize(SS.Constants(), p)
0 commit comments