Skip to content

Commit 3838c5e

Browse files
[ci-skip] Revert "Apply JuliaFormatter to fix code formatting"
This reverts commit 5af4c29. Reverting JuliaFormatter changes due to formatting issues being addressed in JuliaFormatter.jl PR#933.
1 parent 0ba4bd5 commit 3838c5e

File tree

1 file changed

+60
-65
lines changed

1 file changed

+60
-65
lines changed

docs/src/example.md

Lines changed: 60 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -9,30 +9,30 @@ using OrdinaryDiffEqTsit5
99
using LinearAlgebra
1010
1111
mutable struct SubproblemParameters{P, Q, R}
12-
p::P # tunable
13-
q::Q
14-
r::R
12+
p::P # tunable
13+
q::Q
14+
r::R
1515
end
1616
1717
mutable struct Parameters{P, C}
18-
subparams::P
19-
coeffs::C # tunable matrix
18+
subparams::P
19+
coeffs::C # tunable matrix
2020
end
2121
2222
# the rhs is `du[i] = p[i] * u[i]^2 + q[i] * u[i] + r[i] * t` for i in 1:length(subparams)
2323
# and `du[length(subparams)+1:end] .= coeffs * u`
2424
function rhs!(du, u, p::Parameters, t)
25-
for (i, subpars) in enumerate(p.subparams)
26-
du[i] = subpars.p * u[i]^2 + subpars.q * u[i] + subpars.r * t
27-
end
28-
N = length(p.subparams)
29-
mul!(view(du, (N + 1):(length(du))), p.coeffs, u)
30-
return nothing
25+
for (i, subpars) in enumerate(p.subparams)
26+
du[i] = subpars.p * u[i]^2 + subpars.q * u[i] + subpars.r * t
27+
end
28+
N = length(p.subparams)
29+
mul!(view(du, (N+1):(length(du))), p.coeffs, u)
30+
return nothing
3131
end
3232
3333
u = sin.(0.1:0.1:1.0)
3434
subparams = [SubproblemParameters(0.1i, 0.2i, 0.3i) for i in 1:5]
35-
p = Parameters(subparams, cos.([0.1i + 0.33j for i in 1:5, j in 1:10]))
35+
p = Parameters(subparams, cos.([0.1i+0.33j for i in 1:5, j in 1:10]))
3636
tspan = (0.0, 1.0)
3737
3838
prob = ODEProblem(rhs!, u, tspan, p)
@@ -47,13 +47,12 @@ using SciMLSensitivity
4747
4848
# 5 subparams[i].p, 50 elements in coeffs
4949
function simulate_with_tunables(tunables)
50-
subpars = [SubproblemParameters(tunables[i], subpar.q, subpar.r)
51-
for (i, subpar) in enumerate(p.subparams)]
52-
coeffs = reshape(tunables[6:end], size(p.coeffs))
53-
newp = Parameters(subpars, coeffs)
54-
newprob = remake(prob; p = newp)
55-
sol = solve(newprob, Tsit5())
56-
return sum(sol.u[end])
50+
subpars = [SubproblemParameters(tunables[i], subpar.q, subpar.r) for (i, subpar) in enumerate(p.subparams)]
51+
coeffs = reshape(tunables[6:end], size(p.coeffs))
52+
newp = Parameters(subpars, coeffs)
53+
newprob = remake(prob; p = newp)
54+
sol = solve(newprob, Tsit5())
55+
return sum(sol.u[end])
5756
end
5857
```
5958

@@ -75,49 +74,46 @@ SS.ismutablescimlstructure(::Parameters) = true
7574
SS.hasportion(::SS.Tunable, ::Parameters) = true
7675
7776
function SS.canonicalize(::SS.Tunable, p::Parameters)
78-
# concatenate all tunable values into a single vector
79-
buffer = vcat([subpar.p for subpar in p.subparams], vec(p.coeffs))
80-
81-
# repack takes a new vector of the same length as `buffer`, and constructs
82-
# a new `Parameters` object using the values from the new vector for tunables
83-
# and retaining old values for other parameters. This is exactly what replace does,
84-
# so we can use that instead.
85-
repack = let p = p
86-
function repack(newbuffer)
87-
SS.replace(SS.Tunable(), p, newbuffer)
88-
end
77+
# concatenate all tunable values into a single vector
78+
buffer = vcat([subpar.p for subpar in p.subparams], vec(p.coeffs))
79+
80+
# repack takes a new vector of the same length as `buffer`, and constructs
81+
# a new `Parameters` object using the values from the new vector for tunables
82+
# and retaining old values for other parameters. This is exactly what replace does,
83+
# so we can use that instead.
84+
repack = let p = p
85+
function repack(newbuffer)
86+
SS.replace(SS.Tunable(), p, newbuffer)
8987
end
90-
# the canonicalized vector, the repack function, and a boolean indicating
91-
# whether the buffer aliases values in the parameter object (here, it doesn't)
92-
return buffer, repack, false
88+
end
89+
# the canonicalized vector, the repack function, and a boolean indicating
90+
# whether the buffer aliases values in the parameter object (here, it doesn't)
91+
return buffer, repack, false
9392
end
9493
9594
function SS.replace(::SS.Tunable, p::Parameters, newbuffer)
96-
N = length(p.subparams) + length(p.coeffs)
97-
@assert length(newbuffer) == N
98-
subparams = [SubproblemParameters(newbuffer[i], subpar.q, subpar.r)
99-
for (i, subpar) in enumerate(p.subparams)]
100-
coeffs = reshape(
101-
view(newbuffer, (length(p.subparams) + 1):length(newbuffer)), size(p.coeffs))
102-
return Parameters(subparams, coeffs)
95+
N = length(p.subparams) + length(p.coeffs)
96+
@assert length(newbuffer) == N
97+
subparams = [SubproblemParameters(newbuffer[i], subpar.q, subpar.r) for (i, subpar) in enumerate(p.subparams)]
98+
coeffs = reshape(view(newbuffer, (length(p.subparams)+1):length(newbuffer)), size(p.coeffs))
99+
return Parameters(subparams, coeffs)
103100
end
104101
105102
function SS.replace!(::SS.Tunable, p::Parameters, newbuffer)
106-
N = length(p.subparams) + length(p.coeffs)
107-
@assert length(newbuffer) == N
108-
for (subpar, val) in zip(p.subparams, newbuffer)
109-
subpar.p = val
110-
end
111-
copyto!(coeffs, view(newbuffer, (length(p.subparams) + 1):length(newbuffer)))
112-
return p
103+
N = length(p.subparams) + length(p.coeffs)
104+
@assert length(newbuffer) == N
105+
for (subpar, val) in zip(p.subparams, newbuffer)
106+
subpar.p = val
107+
end
108+
copyto!(coeffs, view(newbuffer, (length(p.subparams)+1):length(newbuffer)))
109+
return p
113110
end
114111
```
115112

116113
Now, we should be able to differentiate through the ODE solve.
117114

118115
```@example basic_tutorial
119-
Zygote.gradient(
120-
simulate_with_tunables, 0.1ones(length(SS.canonicalize(SS.Tunable(), p)[1])))
116+
Zygote.gradient(simulate_with_tunables, 0.1ones(length(SS.canonicalize(SS.Tunable(), p)[1])))
121117
```
122118

123119
We can also implement a `Constants` portion to store the rest of the values:
@@ -126,30 +122,29 @@ We can also implement a `Constants` portion to store the rest of the values:
126122
SS.hasportion(::SS.Constants, ::Parameters) = true
127123
128124
function SS.canonicalize(::SS.Constants, p::Parameters)
129-
buffer = mapreduce(vcat, p.subparams) do subpar
130-
[subpar.q, subpar.r]
131-
end
132-
repack = let p = p
133-
function repack(newbuffer)
134-
SS.replace(SS.Constants(), p, newbuffer)
135-
end
125+
buffer = mapreduce(vcat, p.subparams) do subpar
126+
[subpar.q, subpar.r]
127+
end
128+
repack = let p = p
129+
function repack(newbuffer)
130+
SS.replace(SS.Constants(), p, newbuffer)
136131
end
132+
end
137133
138-
return buffer, repack, false
134+
return buffer, repack, false
139135
end
140136
141137
function SS.replace(::SS.Constants, p::Parameters, newbuffer)
142-
subpars = [SubproblemParameters(p.subparams[i].p, newbuffer[2i - 1], newbuffer[2i])
143-
for i in eachindex(p.subparams)]
144-
return Parameters(subpars, p.coeffs)
138+
subpars = [SubproblemParameters(p.subparams[i].p, newbuffer[2i-1], newbuffer[2i]) for i in eachindex(p.subparams)]
139+
return Parameters(subpars, p.coeffs)
145140
end
146141
147142
function SS.replace!(::SS.Constants, p::Parameters, newbuffer)
148-
for i in eachindex(p.subparams)
149-
p.subparams[i].q = newbuffer[2i - 1]
150-
p.subparams[i].r = newbuffer[2i]
151-
end
152-
return p
143+
for i in eachindex(p.subparams)
144+
p.subparams[i].q = newbuffer[2i-1]
145+
p.subparams[i].r = newbuffer[2i]
146+
end
147+
return p
153148
end
154149
155150
buf, repack, alias = SS.canonicalize(SS.Constants(), p)

0 commit comments

Comments
 (0)