Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 61 additions & 58 deletions docs/src/example.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,25 @@ using OrdinaryDiffEqTsit5
using LinearAlgebra

mutable struct SubproblemParameters{P, Q, R}
p::P # tunable
q::Q
r::R
p::P # tunable
q::Q
r::R
end

mutable struct Parameters{P, C}
subparams::P
coeffs::C # tunable matrix
subparams::P
coeffs::C # tunable matrix
end

# the rhs is `du[i] = p[i] * u[i]^2 + q[i] * u[i] + r[i] * t` for i in 1:length(subparams)
# and `du[length(subparams)+1:end] .= coeffs * u`
function rhs!(du, u, p::Parameters, t)
for (i, subpars) in enumerate(p.subparams)
du[i] = subpars.p * u[i]^2 + subpars.q * u[i] + subpars.r * t
end
N = length(p.subparams)
mul!(view(du, (N+1):(length(du))), p.coeffs, u)
return nothing
for (i, subpars) in enumerate(p.subparams)
du[i] = subpars.p * u[i]^2 + subpars.q * u[i] + subpars.r * t
end
N = length(p.subparams)
mul!(view(du, (N + 1):(length(du))), p.coeffs, u)
return nothing
end

u = sin.(0.1:0.1:1.0)
Expand All @@ -47,12 +47,13 @@ using SciMLSensitivity

# 5 subparams[i].p, 50 elements in coeffs
function simulate_with_tunables(tunables)
subpars = [SubproblemParameters(tunables[i], subpar.q, subpar.r) for (i, subpar) in enumerate(p.subparams)]
coeffs = reshape(tunables[6:end], size(p.coeffs))
newp = Parameters(subpars, coeffs)
newprob = remake(prob; p = newp)
sol = solve(newprob, Tsit5())
return sum(sol.u[end])
subpars = [SubproblemParameters(tunables[i], subpar.q, subpar.r)
for (i, subpar) in enumerate(p.subparams)]
coeffs = reshape(tunables[6:end], size(p.coeffs))
newp = Parameters(subpars, coeffs)
newprob = remake(prob; p = newp)
sol = solve(newprob, Tsit5())
return sum(sol.u[end])
end
```

Expand All @@ -74,39 +75,40 @@ SS.ismutablescimlstructure(::Parameters) = true
SS.hasportion(::SS.Tunable, ::Parameters) = true

function SS.canonicalize(::SS.Tunable, p::Parameters)
# concatenate all tunable values into a single vector
buffer = vcat([subpar.p for subpar in p.subparams], vec(p.coeffs))

# repack takes a new vector of the same length as `buffer`, and constructs
# a new `Parameters` object using the values from the new vector for tunables
# and retaining old values for other parameters. This is exactly what replace does,
# so we can use that instead.
repack = let p = p
function repack(newbuffer)
SS.replace(SS.Tunable(), p, newbuffer)
# concatenate all tunable values into a single vector
buffer = vcat([subpar.p for subpar in p.subparams], vec(p.coeffs))

# repack takes a new vector of the same length as `buffer`, and constructs
# a new `Parameters` object using the values from the new vector for tunables
# and retaining old values for other parameters. This is exactly what replace does,
# so we can use that instead.
repack = let p = p
function repack(newbuffer)
SS.replace(SS.Tunable(), p, newbuffer)
end
end
end
# the canonicalized vector, the repack function, and a boolean indicating
# whether the buffer aliases values in the parameter object (here, it doesn't)
return buffer, repack, false
# the canonicalized vector, the repack function, and a boolean indicating
# whether the buffer aliases values in the parameter object (here, it doesn't)
return buffer, repack, false
end

function SS.replace(::SS.Tunable, p::Parameters, newbuffer)
N = length(p.subparams) + length(p.coeffs)
@assert length(newbuffer) == N
subparams = [SubproblemParameters(newbuffer[i], subpar.q, subpar.r) for (i, subpar) in enumerate(p.subparams)]
coeffs = reshape(view(newbuffer, (length(p.subparams)+1):length(newbuffer)), size(p.coeffs))
return Parameters(subparams, coeffs)
N = length(p.subparams) + length(p.coeffs)
@assert length(newbuffer) == N
subparams = [SubproblemParameters(newbuffer[i], subpar.q, subpar.r)
for (i, subpar) in enumerate(p.subparams)]
coeffs = reshape(view(newbuffer, (length(p.subparams) + 1):length(newbuffer)), size(p.coeffs))
return Parameters(subparams, coeffs)
end

function SS.replace!(::SS.Tunable, p::Parameters, newbuffer)
N = length(p.subparams) + length(p.coeffs)
@assert length(newbuffer) == N
for (subpar, val) in zip(p.subparams, newbuffer)
subpar.p = val
end
copyto!(coeffs, view(newbuffer, (length(p.subparams)+1):length(newbuffer)))
return p
N = length(p.subparams) + length(p.coeffs)
@assert length(newbuffer) == N
for (subpar, val) in zip(p.subparams, newbuffer)
subpar.p = val
end
copyto!(coeffs, view(newbuffer, (length(p.subparams) + 1):length(newbuffer)))
return p
end
```

Expand All @@ -122,29 +124,30 @@ We can also implement a `Constants` portion to store the rest of the values:
SS.hasportion(::SS.Constants, ::Parameters) = true

function SS.canonicalize(::SS.Constants, p::Parameters)
buffer = mapreduce(vcat, p.subparams) do subpar
[subpar.q, subpar.r]
end
repack = let p = p
function repack(newbuffer)
SS.replace(SS.Constants(), p, newbuffer)
buffer = mapreduce(vcat, p.subparams) do subpar
[subpar.q, subpar.r]
end
repack = let p = p
function repack(newbuffer)
SS.replace(SS.Constants(), p, newbuffer)
end
end
end

return buffer, repack, false
return buffer, repack, false
end

function SS.replace(::SS.Constants, p::Parameters, newbuffer)
subpars = [SubproblemParameters(p.subparams[i].p, newbuffer[2i-1], newbuffer[2i]) for i in eachindex(p.subparams)]
return Parameters(subpars, p.coeffs)
subpars = [SubproblemParameters(p.subparams[i].p, newbuffer[2i - 1], newbuffer[2i])
for i in eachindex(p.subparams)]
return Parameters(subpars, p.coeffs)
end

function SS.replace!(::SS.Constants, p::Parameters, newbuffer)
for i in eachindex(p.subparams)
p.subparams[i].q = newbuffer[2i-1]
p.subparams[i].r = newbuffer[2i]
end
return p
for i in eachindex(p.subparams)
p.subparams[i].q = newbuffer[2i - 1]
p.subparams[i].r = newbuffer[2i]
end
return p
end

buf, repack, alias = SS.canonicalize(SS.Constants(), p)
Expand Down
Loading