diff --git a/docs/src/example.md b/docs/src/example.md index 5a18074..b2cb625 100644 --- a/docs/src/example.md +++ b/docs/src/example.md @@ -9,30 +9,30 @@ using OrdinaryDiffEqTsit5 using LinearAlgebra mutable struct SubproblemParameters{P, Q, R} - p::P # tunable - q::Q - r::R + p::P # tunable + q::Q + r::R end mutable struct Parameters{P, C} - subparams::P - coeffs::C # tunable matrix + subparams::P + coeffs::C # tunable matrix end # the rhs is `du[i] = p[i] * u[i]^2 + q[i] * u[i] + r[i] * t` for i in 1:length(subparams) # and `du[length(subparams)+1:end] .= coeffs * u` function rhs!(du, u, p::Parameters, t) - for (i, subpars) in enumerate(p.subparams) - du[i] = subpars.p * u[i]^2 + subpars.q * u[i] + subpars.r * t - end - N = length(p.subparams) - mul!(view(du, (N+1):(length(du))), p.coeffs, u) - return nothing + for (i, subpars) in enumerate(p.subparams) + du[i] = subpars.p * u[i]^2 + subpars.q * u[i] + subpars.r * t + end + N = length(p.subparams) + mul!(view(du, (N + 1):(length(du))), p.coeffs, u) + return nothing end u = sin.(0.1:0.1:1.0) subparams = [SubproblemParameters(0.1i, 0.2i, 0.3i) for i in 1:5] -p = Parameters(subparams, cos.([0.1i+0.33j for i in 1:5, j in 1:10])) +p = Parameters(subparams, cos.([0.1i + 0.33j for i in 1:5, j in 1:10])) tspan = (0.0, 1.0) prob = ODEProblem(rhs!, u, tspan, p) @@ -47,12 +47,13 @@ using SciMLSensitivity # 5 subparams[i].p, 50 elements in coeffs function simulate_with_tunables(tunables) - subpars = [SubproblemParameters(tunables[i], subpar.q, subpar.r) for (i, subpar) in enumerate(p.subparams)] - coeffs = reshape(tunables[6:end], size(p.coeffs)) - newp = Parameters(subpars, coeffs) - newprob = remake(prob; p = newp) - sol = solve(newprob, Tsit5()) - return sum(sol.u[end]) + subpars = [SubproblemParameters(tunables[i], subpar.q, subpar.r) + for (i, subpar) in enumerate(p.subparams)] + coeffs = reshape(tunables[6:end], size(p.coeffs)) + newp = Parameters(subpars, coeffs) + newprob = remake(prob; p = newp) + sol = solve(newprob, Tsit5()) + return sum(sol.u[end]) end ``` @@ -74,46 +75,49 @@ SS.ismutablescimlstructure(::Parameters) = true SS.hasportion(::SS.Tunable, ::Parameters) = true function SS.canonicalize(::SS.Tunable, p::Parameters) - # concatenate all tunable values into a single vector - buffer = vcat([subpar.p for subpar in p.subparams], vec(p.coeffs)) - - # repack takes a new vector of the same length as `buffer`, and constructs - # a new `Parameters` object using the values from the new vector for tunables - # and retaining old values for other parameters. This is exactly what replace does, - # so we can use that instead. - repack = let p = p - function repack(newbuffer) - SS.replace(SS.Tunable(), p, newbuffer) + # concatenate all tunable values into a single vector + buffer = vcat([subpar.p for subpar in p.subparams], vec(p.coeffs)) + + # repack takes a new vector of the same length as `buffer`, and constructs + # a new `Parameters` object using the values from the new vector for tunables + # and retaining old values for other parameters. This is exactly what replace does, + # so we can use that instead. + repack = let p = p + function repack(newbuffer) + SS.replace(SS.Tunable(), p, newbuffer) + end end - end - # the canonicalized vector, the repack function, and a boolean indicating - # whether the buffer aliases values in the parameter object (here, it doesn't) - return buffer, repack, false + # the canonicalized vector, the repack function, and a boolean indicating + # whether the buffer aliases values in the parameter object (here, it doesn't) + return buffer, repack, false end function SS.replace(::SS.Tunable, p::Parameters, newbuffer) - N = length(p.subparams) + length(p.coeffs) - @assert length(newbuffer) == N - subparams = [SubproblemParameters(newbuffer[i], subpar.q, subpar.r) for (i, subpar) in enumerate(p.subparams)] - coeffs = reshape(view(newbuffer, (length(p.subparams)+1):length(newbuffer)), size(p.coeffs)) - return Parameters(subparams, coeffs) + N = length(p.subparams) + length(p.coeffs) + @assert length(newbuffer) == N + subparams = [SubproblemParameters(newbuffer[i], subpar.q, subpar.r) + for (i, subpar) in enumerate(p.subparams)] + coeffs = reshape( + view(newbuffer, (length(p.subparams) + 1):length(newbuffer)), size(p.coeffs)) + return Parameters(subparams, coeffs) end function SS.replace!(::SS.Tunable, p::Parameters, newbuffer) - N = length(p.subparams) + length(p.coeffs) - @assert length(newbuffer) == N - for (subpar, val) in zip(p.subparams, newbuffer) - subpar.p = val - end - copyto!(coeffs, view(newbuffer, (length(p.subparams)+1):length(newbuffer))) - return p + N = length(p.subparams) + length(p.coeffs) + @assert length(newbuffer) == N + for (subpar, val) in zip(p.subparams, newbuffer) + subpar.p = val + end + copyto!(coeffs, view(newbuffer, (length(p.subparams) + 1):length(newbuffer))) + return p end ``` Now, we should be able to differentiate through the ODE solve. ```@example basic_tutorial -Zygote.gradient(simulate_with_tunables, 0.1ones(length(SS.canonicalize(SS.Tunable(), p)[1]))) +Zygote.gradient( + simulate_with_tunables, 0.1ones(length(SS.canonicalize(SS.Tunable(), p)[1]))) ``` We can also implement a `Constants` portion to store the rest of the values: @@ -122,29 +126,30 @@ We can also implement a `Constants` portion to store the rest of the values: SS.hasportion(::SS.Constants, ::Parameters) = true function SS.canonicalize(::SS.Constants, p::Parameters) - buffer = mapreduce(vcat, p.subparams) do subpar - [subpar.q, subpar.r] - end - repack = let p = p - function repack(newbuffer) - SS.replace(SS.Constants(), p, newbuffer) + buffer = mapreduce(vcat, p.subparams) do subpar + [subpar.q, subpar.r] + end + repack = let p = p + function repack(newbuffer) + SS.replace(SS.Constants(), p, newbuffer) + end end - end - return buffer, repack, false + return buffer, repack, false end function SS.replace(::SS.Constants, p::Parameters, newbuffer) - subpars = [SubproblemParameters(p.subparams[i].p, newbuffer[2i-1], newbuffer[2i]) for i in eachindex(p.subparams)] - return Parameters(subpars, p.coeffs) + subpars = [SubproblemParameters(p.subparams[i].p, newbuffer[2i - 1], newbuffer[2i]) + for i in eachindex(p.subparams)] + return Parameters(subpars, p.coeffs) end function SS.replace!(::SS.Constants, p::Parameters, newbuffer) - for i in eachindex(p.subparams) - p.subparams[i].q = newbuffer[2i-1] - p.subparams[i].r = newbuffer[2i] - end - return p + for i in eachindex(p.subparams) + p.subparams[i].q = newbuffer[2i - 1] + p.subparams[i].r = newbuffer[2i] + end + return p end buf, repack, alias = SS.canonicalize(SS.Constants(), p)