@@ -9,25 +9,25 @@ using OrdinaryDiffEqTsit5
99using LinearAlgebra
1010
1111mutable struct SubproblemParameters{P, Q, R}
12- p::P # tunable
13- q::Q
14- r::R
12+ p::P # tunable
13+ q::Q
14+ r::R
1515end
1616
1717mutable struct Parameters{P, C}
18- subparams::P
19- coeffs::C # tunable matrix
18+ subparams::P
19+ coeffs::C # tunable matrix
2020end
2121
2222# the rhs is `du[i] = p[i] * u[i]^2 + q[i] * u[i] + r[i] * t` for i in 1:length(subparams)
2323# and `du[length(subparams)+1:end] .= coeffs * u`
2424function rhs!(du, u, p::Parameters, t)
25- for (i, subpars) in enumerate(p.subparams)
26- du[i] = subpars.p * u[i]^2 + subpars.q * u[i] + subpars.r * t
27- end
28- N = length(p.subparams)
29- mul!(view(du, (N+ 1):(length(du))), p.coeffs, u)
30- return nothing
25+ for (i, subpars) in enumerate(p.subparams)
26+ du[i] = subpars.p * u[i]^2 + subpars.q * u[i] + subpars.r * t
27+ end
28+ N = length(p.subparams)
29+ mul!(view(du, (N + 1):(length(du))), p.coeffs, u)
30+ return nothing
3131end
3232
3333u = sin.(0.1:0.1:1.0)
@@ -47,12 +47,13 @@ using SciMLSensitivity
4747
4848# 5 subparams[i].p, 50 elements in coeffs
4949function simulate_with_tunables(tunables)
50- subpars = [SubproblemParameters(tunables[i], subpar.q, subpar.r) for (i, subpar) in enumerate(p.subparams)]
51- coeffs = reshape(tunables[6:end], size(p.coeffs))
52- newp = Parameters(subpars, coeffs)
53- newprob = remake(prob; p = newp)
54- sol = solve(newprob, Tsit5())
55- return sum(sol.u[end])
50+ subpars = [SubproblemParameters(tunables[i], subpar.q, subpar.r)
51+ for (i, subpar) in enumerate(p.subparams)]
52+ coeffs = reshape(tunables[6:end], size(p.coeffs))
53+ newp = Parameters(subpars, coeffs)
54+ newprob = remake(prob; p = newp)
55+ sol = solve(newprob, Tsit5())
56+ return sum(sol.u[end])
5657end
5758```
5859
@@ -74,39 +75,40 @@ SS.ismutablescimlstructure(::Parameters) = true
7475SS.hasportion(::SS.Tunable, ::Parameters) = true
7576
7677function SS.canonicalize(::SS.Tunable, p::Parameters)
77- # concatenate all tunable values into a single vector
78- buffer = vcat([subpar.p for subpar in p.subparams], vec(p.coeffs))
79-
80- # repack takes a new vector of the same length as `buffer`, and constructs
81- # a new `Parameters` object using the values from the new vector for tunables
82- # and retaining old values for other parameters. This is exactly what replace does,
83- # so we can use that instead.
84- repack = let p = p
85- function repack(newbuffer)
86- SS.replace(SS.Tunable(), p, newbuffer)
78+ # concatenate all tunable values into a single vector
79+ buffer = vcat([subpar.p for subpar in p.subparams], vec(p.coeffs))
80+
81+ # repack takes a new vector of the same length as `buffer`, and constructs
82+ # a new `Parameters` object using the values from the new vector for tunables
83+ # and retaining old values for other parameters. This is exactly what replace does,
84+ # so we can use that instead.
85+ repack = let p = p
86+ function repack(newbuffer)
87+ SS.replace(SS.Tunable(), p, newbuffer)
88+ end
8789 end
88- end
89- # the canonicalized vector, the repack function, and a boolean indicating
90- # whether the buffer aliases values in the parameter object (here, it doesn't)
91- return buffer, repack, false
90+ # the canonicalized vector, the repack function, and a boolean indicating
91+ # whether the buffer aliases values in the parameter object (here, it doesn't)
92+ return buffer, repack, false
9293end
9394
9495function SS.replace(::SS.Tunable, p::Parameters, newbuffer)
95- N = length(p.subparams) + length(p.coeffs)
96- @assert length(newbuffer) == N
97- subparams = [SubproblemParameters(newbuffer[i], subpar.q, subpar.r) for (i, subpar) in enumerate(p.subparams)]
98- coeffs = reshape(view(newbuffer, (length(p.subparams)+1):length(newbuffer)), size(p.coeffs))
99- return Parameters(subparams, coeffs)
96+ N = length(p.subparams) + length(p.coeffs)
97+ @assert length(newbuffer) == N
98+ subparams = [SubproblemParameters(newbuffer[i], subpar.q, subpar.r)
99+ for (i, subpar) in enumerate(p.subparams)]
100+ coeffs = reshape(view(newbuffer, (length(p.subparams) + 1):length(newbuffer)), size(p.coeffs))
101+ return Parameters(subparams, coeffs)
100102end
101103
102104function SS.replace!(::SS.Tunable, p::Parameters, newbuffer)
103- N = length(p.subparams) + length(p.coeffs)
104- @assert length(newbuffer) == N
105- for (subpar, val) in zip(p.subparams, newbuffer)
106- subpar.p = val
107- end
108- copyto!(coeffs, view(newbuffer, (length(p.subparams)+ 1):length(newbuffer)))
109- return p
105+ N = length(p.subparams) + length(p.coeffs)
106+ @assert length(newbuffer) == N
107+ for (subpar, val) in zip(p.subparams, newbuffer)
108+ subpar.p = val
109+ end
110+ copyto!(coeffs, view(newbuffer, (length(p.subparams) + 1):length(newbuffer)))
111+ return p
110112end
111113```
112114
@@ -122,29 +124,30 @@ We can also implement a `Constants` portion to store the rest of the values:
122124SS.hasportion(::SS.Constants, ::Parameters) = true
123125
124126function SS.canonicalize(::SS.Constants, p::Parameters)
125- buffer = mapreduce(vcat, p.subparams) do subpar
126- [subpar.q, subpar.r]
127- end
128- repack = let p = p
129- function repack(newbuffer)
130- SS.replace(SS.Constants(), p, newbuffer)
127+ buffer = mapreduce(vcat, p.subparams) do subpar
128+ [subpar.q, subpar.r]
129+ end
130+ repack = let p = p
131+ function repack(newbuffer)
132+ SS.replace(SS.Constants(), p, newbuffer)
133+ end
131134 end
132- end
133135
134- return buffer, repack, false
136+ return buffer, repack, false
135137end
136138
137139function SS.replace(::SS.Constants, p::Parameters, newbuffer)
138- subpars = [SubproblemParameters(p.subparams[i].p, newbuffer[2i-1], newbuffer[2i]) for i in eachindex(p.subparams)]
139- return Parameters(subpars, p.coeffs)
140+ subpars = [SubproblemParameters(p.subparams[i].p, newbuffer[2i - 1], newbuffer[2i])
141+ for i in eachindex(p.subparams)]
142+ return Parameters(subpars, p.coeffs)
140143end
141144
142145function SS.replace!(::SS.Constants, p::Parameters, newbuffer)
143- for i in eachindex(p.subparams)
144- p.subparams[i].q = newbuffer[2i- 1]
145- p.subparams[i].r = newbuffer[2i]
146- end
147- return p
146+ for i in eachindex(p.subparams)
147+ p.subparams[i].q = newbuffer[2i - 1]
148+ p.subparams[i].r = newbuffer[2i]
149+ end
150+ return p
148151end
149152
150153buf, repack, alias = SS.canonicalize(SS.Constants(), p)
0 commit comments