Skip to content

Commit a582542

Browse files
Merge pull request #35 from ChrisRackauckas/fix-formatting
Apply JuliaFormatter to fix code formatting
2 parents 3bce3d3 + 5af4c29 commit a582542

File tree

1 file changed

+65
-60
lines changed

1 file changed

+65
-60
lines changed

docs/src/example.md

Lines changed: 65 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -9,30 +9,30 @@ using OrdinaryDiffEqTsit5
99
using LinearAlgebra
1010
1111
mutable struct SubproblemParameters{P, Q, R}
12-
p::P # tunable
13-
q::Q
14-
r::R
12+
p::P # tunable
13+
q::Q
14+
r::R
1515
end
1616
1717
mutable struct Parameters{P, C}
18-
subparams::P
19-
coeffs::C # tunable matrix
18+
subparams::P
19+
coeffs::C # tunable matrix
2020
end
2121
2222
# the rhs is `du[i] = p[i] * u[i]^2 + q[i] * u[i] + r[i] * t` for i in 1:length(subparams)
2323
# and `du[length(subparams)+1:end] .= coeffs * u`
2424
function rhs!(du, u, p::Parameters, t)
25-
for (i, subpars) in enumerate(p.subparams)
26-
du[i] = subpars.p * u[i]^2 + subpars.q * u[i] + subpars.r * t
27-
end
28-
N = length(p.subparams)
29-
mul!(view(du, (N+1):(length(du))), p.coeffs, u)
30-
return nothing
25+
for (i, subpars) in enumerate(p.subparams)
26+
du[i] = subpars.p * u[i]^2 + subpars.q * u[i] + subpars.r * t
27+
end
28+
N = length(p.subparams)
29+
mul!(view(du, (N + 1):(length(du))), p.coeffs, u)
30+
return nothing
3131
end
3232
3333
u = sin.(0.1:0.1:1.0)
3434
subparams = [SubproblemParameters(0.1i, 0.2i, 0.3i) for i in 1:5]
35-
p = Parameters(subparams, cos.([0.1i+0.33j for i in 1:5, j in 1:10]))
35+
p = Parameters(subparams, cos.([0.1i + 0.33j for i in 1:5, j in 1:10]))
3636
tspan = (0.0, 1.0)
3737
3838
prob = ODEProblem(rhs!, u, tspan, p)
@@ -47,12 +47,13 @@ using SciMLSensitivity
4747
4848
# 5 subparams[i].p, 50 elements in coeffs
4949
function simulate_with_tunables(tunables)
50-
subpars = [SubproblemParameters(tunables[i], subpar.q, subpar.r) for (i, subpar) in enumerate(p.subparams)]
51-
coeffs = reshape(tunables[6:end], size(p.coeffs))
52-
newp = Parameters(subpars, coeffs)
53-
newprob = remake(prob; p = newp)
54-
sol = solve(newprob, Tsit5())
55-
return sum(sol.u[end])
50+
subpars = [SubproblemParameters(tunables[i], subpar.q, subpar.r)
51+
for (i, subpar) in enumerate(p.subparams)]
52+
coeffs = reshape(tunables[6:end], size(p.coeffs))
53+
newp = Parameters(subpars, coeffs)
54+
newprob = remake(prob; p = newp)
55+
sol = solve(newprob, Tsit5())
56+
return sum(sol.u[end])
5657
end
5758
```
5859

@@ -74,46 +75,49 @@ SS.ismutablescimlstructure(::Parameters) = true
7475
SS.hasportion(::SS.Tunable, ::Parameters) = true
7576
7677
function SS.canonicalize(::SS.Tunable, p::Parameters)
77-
# concatenate all tunable values into a single vector
78-
buffer = vcat([subpar.p for subpar in p.subparams], vec(p.coeffs))
79-
80-
# repack takes a new vector of the same length as `buffer`, and constructs
81-
# a new `Parameters` object using the values from the new vector for tunables
82-
# and retaining old values for other parameters. This is exactly what replace does,
83-
# so we can use that instead.
84-
repack = let p = p
85-
function repack(newbuffer)
86-
SS.replace(SS.Tunable(), p, newbuffer)
78+
# concatenate all tunable values into a single vector
79+
buffer = vcat([subpar.p for subpar in p.subparams], vec(p.coeffs))
80+
81+
# repack takes a new vector of the same length as `buffer`, and constructs
82+
# a new `Parameters` object using the values from the new vector for tunables
83+
# and retaining old values for other parameters. This is exactly what replace does,
84+
# so we can use that instead.
85+
repack = let p = p
86+
function repack(newbuffer)
87+
SS.replace(SS.Tunable(), p, newbuffer)
88+
end
8789
end
88-
end
89-
# the canonicalized vector, the repack function, and a boolean indicating
90-
# whether the buffer aliases values in the parameter object (here, it doesn't)
91-
return buffer, repack, false
90+
# the canonicalized vector, the repack function, and a boolean indicating
91+
# whether the buffer aliases values in the parameter object (here, it doesn't)
92+
return buffer, repack, false
9293
end
9394
9495
function SS.replace(::SS.Tunable, p::Parameters, newbuffer)
95-
N = length(p.subparams) + length(p.coeffs)
96-
@assert length(newbuffer) == N
97-
subparams = [SubproblemParameters(newbuffer[i], subpar.q, subpar.r) for (i, subpar) in enumerate(p.subparams)]
98-
coeffs = reshape(view(newbuffer, (length(p.subparams)+1):length(newbuffer)), size(p.coeffs))
99-
return Parameters(subparams, coeffs)
96+
N = length(p.subparams) + length(p.coeffs)
97+
@assert length(newbuffer) == N
98+
subparams = [SubproblemParameters(newbuffer[i], subpar.q, subpar.r)
99+
for (i, subpar) in enumerate(p.subparams)]
100+
coeffs = reshape(
101+
view(newbuffer, (length(p.subparams) + 1):length(newbuffer)), size(p.coeffs))
102+
return Parameters(subparams, coeffs)
100103
end
101104
102105
function SS.replace!(::SS.Tunable, p::Parameters, newbuffer)
103-
N = length(p.subparams) + length(p.coeffs)
104-
@assert length(newbuffer) == N
105-
for (subpar, val) in zip(p.subparams, newbuffer)
106-
subpar.p = val
107-
end
108-
copyto!(coeffs, view(newbuffer, (length(p.subparams)+1):length(newbuffer)))
109-
return p
106+
N = length(p.subparams) + length(p.coeffs)
107+
@assert length(newbuffer) == N
108+
for (subpar, val) in zip(p.subparams, newbuffer)
109+
subpar.p = val
110+
end
111+
copyto!(coeffs, view(newbuffer, (length(p.subparams) + 1):length(newbuffer)))
112+
return p
110113
end
111114
```
112115

113116
Now, we should be able to differentiate through the ODE solve.
114117

115118
```@example basic_tutorial
116-
Zygote.gradient(simulate_with_tunables, 0.1ones(length(SS.canonicalize(SS.Tunable(), p)[1])))
119+
Zygote.gradient(
120+
simulate_with_tunables, 0.1ones(length(SS.canonicalize(SS.Tunable(), p)[1])))
117121
```
118122

119123
We can also implement a `Constants` portion to store the rest of the values:
@@ -122,29 +126,30 @@ We can also implement a `Constants` portion to store the rest of the values:
122126
SS.hasportion(::SS.Constants, ::Parameters) = true
123127
124128
function SS.canonicalize(::SS.Constants, p::Parameters)
125-
buffer = mapreduce(vcat, p.subparams) do subpar
126-
[subpar.q, subpar.r]
127-
end
128-
repack = let p = p
129-
function repack(newbuffer)
130-
SS.replace(SS.Constants(), p, newbuffer)
129+
buffer = mapreduce(vcat, p.subparams) do subpar
130+
[subpar.q, subpar.r]
131+
end
132+
repack = let p = p
133+
function repack(newbuffer)
134+
SS.replace(SS.Constants(), p, newbuffer)
135+
end
131136
end
132-
end
133137
134-
return buffer, repack, false
138+
return buffer, repack, false
135139
end
136140
137141
function SS.replace(::SS.Constants, p::Parameters, newbuffer)
138-
subpars = [SubproblemParameters(p.subparams[i].p, newbuffer[2i-1], newbuffer[2i]) for i in eachindex(p.subparams)]
139-
return Parameters(subpars, p.coeffs)
142+
subpars = [SubproblemParameters(p.subparams[i].p, newbuffer[2i - 1], newbuffer[2i])
143+
for i in eachindex(p.subparams)]
144+
return Parameters(subpars, p.coeffs)
140145
end
141146
142147
function SS.replace!(::SS.Constants, p::Parameters, newbuffer)
143-
for i in eachindex(p.subparams)
144-
p.subparams[i].q = newbuffer[2i-1]
145-
p.subparams[i].r = newbuffer[2i]
146-
end
147-
return p
148+
for i in eachindex(p.subparams)
149+
p.subparams[i].q = newbuffer[2i - 1]
150+
p.subparams[i].r = newbuffer[2i]
151+
end
152+
return p
148153
end
149154
150155
buf, repack, alias = SS.canonicalize(SS.Constants(), p)

0 commit comments

Comments
 (0)