Skip to content

Commit 2bbb416

Browse files
committed
wip: simplifying allocations in DiffSolvers
1 parent 8048dbe commit 2bbb416

File tree

7 files changed

+120
-107
lines changed

7 files changed

+120
-107
lines changed

src/estimator/execute.jl

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ function remove_op!(estim::StateEstimator, ym, d, u=nothing)
1818
end
1919

2020
@doc raw"""
21-
f̂!(x̂next0, û0, estim::StateEstimator, model::SimModel, x̂0, u0, d0) -> nothing
21+
f̂!(x̂0next, û0, x0i, estim::StateEstimator, model::SimModel, x̂0, u0, d0) -> nothing
2222
2323
Mutating state function ``\mathbf{f̂}`` of the augmented model.
2424
@@ -30,39 +30,40 @@ function returns the next state of the augmented model, defined as:
3030
\mathbf{ŷ_0}(k) &= \mathbf{ĥ}\Big(\mathbf{x̂_0}(k), \mathbf{d_0}(k)\Big)
3131
\end{aligned}
3232
```
33-
where ``\mathbf{x̂_0}(k+1)`` is stored in `x̂next0` argument. The method mutates `x̂next0` and
34-
`û0` in place, the latter stores the input vector of the augmented model
35-
``\mathbf{u_0 + ŷ_{s_u}}``. The model parameter vector `model.p` is not included in the
36-
function signature for conciseness.
33+
where ``\mathbf{x̂_0}(k+1)`` is stored in `x̂0next` argument. The method mutates `x̂0next`,
34+
`û0` and `x0i` in place. The argument `û0` is the input vector of the augmented model,
35+
computed by ``\mathbf{û_0 = u_0 + ŷ_{s_u}}``. The argument `x0i` is used to store the
36+
intermediate stage values of `model.solver` (when applicable). The model parameter vector
37+
`model.p` is not included in the function signature for conciseness.
3738
"""
38-
function f̂!(x̂next0, û0, estim::StateEstimator, model::SimModel, x̂0, u0, d0)
39-
return f̂!(x̂next0, û0, model, estim.As, estim.Cs_u, x̂0, u0, d0)
39+
function f̂!(x̂0next, û0, x0i, estim::StateEstimator, model::SimModel, x̂0, u0, d0)
40+
return f̂!(x̂0next, û0, x0i, model, estim.As, estim.Cs_u, x̂0, u0, d0)
4041
end
4142

4243
"""
43-
f̂!(x̂next0, _ , estim::StateEstimator, model::LinModel, x̂0, u0, d0) -> nothing
44+
f̂!(x̂0next, _ , _ , estim::StateEstimator, model::LinModel, x̂0, u0, d0) -> nothing
4445
4546
Use the augmented model matrices if `model` is a [`LinModel`](@ref).
4647
"""
47-
function f̂!(x̂next0, _ , estim::StateEstimator, ::LinModel, x̂0, u0, d0)
48-
mul!(x̂next0, estim.Â, x̂0)
49-
mul!(x̂next0, estim.B̂u, u0, 1, 1)
50-
mul!(x̂next0, estim.B̂d, d0, 1, 1)
48+
function f̂!(x̂0next, _ , _ , estim::StateEstimator, ::LinModel, x̂0, u0, d0)
49+
mul!(x̂0next, estim.Â, x̂0)
50+
mul!(x̂0next, estim.B̂u, u0, 1, 1)
51+
mul!(x̂0next, estim.B̂d, d0, 1, 1)
5152
return nothing
5253
end
5354

5455
"""
55-
f̂!(x̂next0, û0, model::SimModel, As, Cs_u, x̂0, u0, d0)
56+
f̂!(x̂0next, û0, x0i, model::SimModel, As, Cs_u, x̂0, u0, d0)
5657
5758
Same than [`f̂!`](@ref) for [`SimModel`](@ref) but without the `estim` argument.
5859
"""
59-
function f̂!(x̂next0, û0, model::SimModel, As, Cs_u, x̂0, u0, d0)
60+
function f̂!(x̂0next, û0, x0i, model::SimModel, As, Cs_u, x̂0, u0, d0)
6061
# `@views` macro avoid copies with matrix slice operator e.g. [a:b]
6162
@views x̂d, x̂s = x̂0[1:model.nx], x̂0[model.nx+1:end]
62-
@views x̂d_next, x̂s_next = x̂next0[1:model.nx], x̂next0[model.nx+1:end]
63-
mul!(û0, Cs_u, x̂s)
63+
@views x̂d_next, x̂s_next = x̂0next[1:model.nx], x̂0next[model.nx+1:end]
64+
mul!(û0, Cs_u, x̂s) # ŷs_u = Cs_u * x̂s
6465
û0 .+= u0
65-
f!(x̂d_next, model, x̂d, û0, d0, model.p)
66+
f!(x̂d_next, x0i, model, x̂d, û0, d0, model.p)
6667
mul!(x̂s_next, As, x̂s)
6768
return nothing
6869
end
@@ -96,7 +97,7 @@ function ĥ!(ŷ0, model::SimModel, Cs_y, x̂0, d0)
9697
# `@views` macro avoid copies with matrix slice operator e.g. [a:b]
9798
@views x̂d, x̂s = x̂0[1:model.nx], x̂0[model.nx+1:end]
9899
h!(ŷ0, model, x̂d, d0, model.p)
99-
mul!(ŷ0, Cs_y, x̂s, 1, 1)
100+
mul!(ŷ0, Cs_y, x̂s, 1, 1) # ŷ0 = ŷ0 + Cs_y*x̂s
100101
return nothing
101102
end
102103

src/estimator/internal_model.jl

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -168,20 +168,25 @@ function matrices_internalmodel(model::SimModel{NT}) where NT<:Real
168168
end
169169

170170
@doc raw"""
171-
f̂!(x̂next0, _ , estim::InternalModel, model::NonLinModel, x̂0, u0, d0)
171+
f̂!(x̂0next, _ , x̂0i, estim::InternalModel, model::NonLinModel, x̂0, u0, d0)
172172
173173
State function ``\mathbf{f̂}`` of [`InternalModel`](@ref) for [`NonLinModel`](@ref).
174174
175-
It calls `model.f!(x̂next0, x̂0, u0 ,d0, model.p)` since this estimator does not augment the states.
175+
It calls `model.solver_f!(x̂0next, x̂0i, x̂0, u0 ,d0, model.p)` directly since this estimator
176+
does not augment the states.
176177
"""
177-
f̂!(x̂next0, _, ::InternalModel, model::NonLinModel, x̂0, u0, d0) = model.f!(x̂next0, x̂0, u0, d0, model.p)
178+
function f̂!(x̂0next, _ , x̂0i, ::InternalModel, model::NonLinModel, x̂0, u0, d0)
179+
return model.solver_f!(x̂0next, x̂0i, x̂0, u0, d0, model.p)
180+
end
178181

179182
@doc raw"""
180183
ĥ!(ŷ0, estim::InternalModel, model::NonLinModel, x̂0, d0)
181184
182-
Output function ``\mathbf{ĥ}`` of [`InternalModel`](@ref), it calls `model.h!`.
185+
Output function ``\mathbf{ĥ}`` of [`InternalModel`](@ref), it calls `model.solver_h!`.
183186
"""
184-
ĥ!(x̂next0, ::InternalModel, model::NonLinModel, x̂0, d0) = model.h!(x̂next0, x̂0, d0, model.p)
187+
function ĥ!(ŷ0, ::InternalModel, model::NonLinModel, x̂0, d0)
188+
return model.solver_h!(ŷ0, x̂0, d0, model.p)
189+
end
185190

186191

187192
@doc raw"""

src/model/linearization.jl

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
"""
2-
get_linearization_func(NT, f!, h!, nu, nx, ny, nd, p, backend) -> linfunc!
2+
get_linearization_func(
3+
NT, solver_f!, solver_h!, nu, nx, ny, nd, ns, p, solver, backend
4+
) -> linfunc!
35
4-
Return the `linfunc!` function that computes the Jacobians of `f!` and `h!` functions.
6+
Return `linfunc!` function that computes Jacobians of `solver_f!` and `solver_h!` functions.
57
68
The function has the following signature:
79
```
@@ -11,31 +13,33 @@ and it should modifies in-place all the arguments before `backend`. The `backend
1113
is an `AbstractADType` object from `DifferentiationInterface`. The `cst_x`, `cst_u` and
1214
`cst_d` are `DifferentiationInterface.Constant` objects with the linearization points.
1315
"""
14-
function get_linearization_func(NT, f!, h!, nu, nx, ny, nd, p, backend)
15-
f_x!(xnext, x, u, d) = f!(xnext, x, u, d, p)
16-
f_u!(xnext, u, x, d) = f!(xnext, x, u, d, p)
17-
f_d!(xnext, d, x, u) = f!(xnext, x, u, d, p)
18-
h_x!(y, x, d) = h!(y, x, d, p)
19-
h_d!(y, d, x) = h!(y, x, d, p)
16+
function get_linearization_func(NT, solver_f!, solver_h!, nu, nx, ny, nd, p, solver, backend)
17+
f_x!(xnext, x, xi, u, d) = solver_f!(xnext, xi, x, u, d, p)
18+
f_u!(xnext, u, xi, x, d) = solver_f!(xnext, xi, x, u, d, p)
19+
f_d!(xnext, d, xi, x, u) = solver_f!(xnext, xi, x, u, d, p)
20+
h_x!(y, x, d) = solver_h!(y, x, d, p)
21+
h_d!(y, d, x) = solver_h!(y, x, d, p)
2022
strict = Val(true)
2123
xnext = zeros(NT, nx)
2224
y = zeros(NT, ny)
2325
x = zeros(NT, nx)
2426
u = zeros(NT, nu)
2527
d = zeros(NT, nd)
28+
xi = zeros(NT, nx*(solver.ni+1))
29+
cache_xi = Cache(xi)
2630
cst_x = Constant(x)
2731
cst_u = Constant(u)
2832
cst_d = Constant(d)
29-
A_prep = prepare_jacobian(f_x!, xnext, backend, x, cst_u, cst_d; strict)
30-
Bu_prep = prepare_jacobian(f_u!, xnext, backend, u, cst_x, cst_d; strict)
31-
Bd_prep = prepare_jacobian(f_d!, xnext, backend, d, cst_x, cst_u; strict)
32-
C_prep = prepare_jacobian(h_x!, y, backend, x, cst_d ; strict)
33-
Dd_prep = prepare_jacobian(h_d!, y, backend, d, cst_x ; strict)
33+
A_prep = prepare_jacobian(f_x!, xnext, backend, x, cache_xi, cst_u, cst_d; strict)
34+
Bu_prep = prepare_jacobian(f_u!, xnext, backend, u, cache_xi, cst_x, cst_d; strict)
35+
Bd_prep = prepare_jacobian(f_d!, xnext, backend, d, cache_xi, cst_x, cst_u; strict)
36+
C_prep = prepare_jacobian(h_x!, y, backend, x, cst_d ; strict)
37+
Dd_prep = prepare_jacobian(h_d!, y, backend, d, cst_x ; strict)
3438
function linfunc!(xnext, y, A, Bu, C, Bd, Dd, backend, x, u, d, cst_x, cst_u, cst_d)
3539
# all the arguments before `backend` are mutated in this function
36-
jacobian!(f_x!, xnext, A, A_prep, backend, x, cst_u, cst_d)
37-
jacobian!(f_u!, xnext, Bu, Bu_prep, backend, u, cst_x, cst_d)
38-
jacobian!(f_d!, xnext, Bd, Bd_prep, backend, d, cst_x, cst_u)
40+
jacobian!(f_x!, xnext, A, A_prep, backend, x, cache_xi, cst_u, cst_d)
41+
jacobian!(f_u!, xnext, Bu, Bu_prep, backend, u, cache_xi, cst_x, cst_d)
42+
jacobian!(f_d!, xnext, Bd, Bd_prep, backend, d, cache_xi, cst_x, cst_u)
3943
jacobian!(h_x!, y, C, C_prep, backend, x, cst_d)
4044
jacobian!(h_d!, y, Dd, Dd_prep, backend, d, cst_x)
4145
return nothing
@@ -154,21 +158,21 @@ function linearize!(
154158
nonlinmodel = model
155159
buffer = nonlinmodel.buffer
156160
# --- remove the operating points of the nonlinear model (typically zeros) ---
157-
x0, u0, d0 = buffer.x, buffer.u, buffer.d
161+
x0, u0, d0, x0i = buffer.x, buffer.u, buffer.d, buffer.xi
158162
x0 .= x .- nonlinmodel.xop
159163
u0 .= u .- nonlinmodel.uop
160164
d0 .= d .- nonlinmodel.dop
161165
# --- compute the Jacobians at linearization points ---
162166
linearize_core!(linmodel, nonlinmodel, x0, u0, d0)
163167
# --- compute the nonlinear model output at operating points ---
164-
xnext0, y0 = linmodel.buffer.x, linmodel.buffer.y
168+
x0next, y0 = linmodel.buffer.x, linmodel.buffer.y
165169
h!(y0, nonlinmodel, x0, d0, model.p)
166170
y = y0
167171
y .= y0 .+ nonlinmodel.yop
168172
# --- compute the nonlinear model next state at operating points ---
169-
f!(xnext0, nonlinmodel, x0, u0, d0, model.p)
170-
xnext = xnext0
171-
xnext .= xnext0 .+ nonlinmodel.fop .- nonlinmodel.xop
173+
f!(x0next, x0i, nonlinmodel, x0, u0, d0, model.p)
174+
xnext = x0next
175+
xnext .= x0next .+ nonlinmodel.fop .- nonlinmodel.xop
172176
# --- modify the linear model operating points ---
173177
linmodel.uop .= u
174178
linmodel.yop .= y
@@ -182,13 +186,13 @@ end
182186

183187
"Call `linfunc!` function to compute the Jacobians of `model` at the linearization point."
184188
function linearize_core!(linmodel::LinModel, model::SimModel, x0, u0, d0)
185-
xnext0, y0 = linmodel.buffer.x, linmodel.buffer.y
189+
x0next, y0 = linmodel.buffer.x, linmodel.buffer.y
186190
A, Bu, C, Bd, Dd = linmodel.A, linmodel.Bu, linmodel.C, linmodel.Bd, linmodel.Dd
187191
cst_x = Constant(x0)
188192
cst_u = Constant(u0)
189193
cst_d = Constant(d0)
190194
backend = model.jacobian
191-
model.linfunc!(xnext0, y0, A, Bu, C, Bd, Dd, backend, x0, u0, d0, cst_x, cst_u, cst_d)
195+
model.linfunc!(x0next, y0, A, Bu, C, Bd, Dd, backend, x0, u0, d0, cst_x, cst_u, cst_d)
192196
return nothing
193197
end
194198

src/model/linmodel.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -258,20 +258,20 @@ function steadystate!(model::LinModel, u0, d0)
258258
end
259259

260260
"""
261-
f!(xnext0, model::LinModel, x0, u0, d0, p) -> nothing
261+
f!(x0next, _ , model::LinModel, x0, u0, d0, _ ) -> nothing
262262
263-
Evaluate `xnext0 = A*x0 + Bu*u0 + Bd*d0` in-place when `model` is a [`LinModel`](@ref).
263+
Evaluate `x0next = A*x0 + Bu*u0 + Bd*d0` in-place when `model` is a [`LinModel`](@ref).
264264
"""
265-
function f!(xnext0, model::LinModel, x0, u0, d0, _ )
266-
mul!(xnext0, model.A, x0)
267-
mul!(xnext0, model.Bu, u0, 1, 1)
268-
mul!(xnext0, model.Bd, d0, 1, 1)
265+
function f!(x0next, _ , model::LinModel, x0, u0, d0, _ )
266+
mul!(x0next, model.A, x0)
267+
mul!(x0next, model.Bu, u0, 1, 1)
268+
mul!(x0next, model.Bd, d0, 1, 1)
269269
return nothing
270270
end
271271

272272

273273
"""
274-
h!(y0, model::LinModel, x0, d0, p) -> nothing
274+
h!(y0, model::LinModel, x0, d0, _ ) -> nothing
275275
276276
Evaluate `y0 = C*x0 + Dd*d0` in-place when `model` is a [`LinModel`](@ref).
277277
"""

src/model/nonlinmodel.jl

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ struct NonLinModel{
88
LF<:Function
99
} <: SimModel{NT}
1010
x0::Vector{NT}
11-
f!::F
12-
h!::H
11+
solver_f!::F
12+
solver_h!::H
1313
p::PT
1414
solver::DS
1515
Ts::NT
@@ -31,7 +31,8 @@ struct NonLinModel{
3131
linfunc!::LF
3232
buffer::SimModelBuffer{NT}
3333
function NonLinModel{NT}(
34-
f!::F, h!::H, Ts, nu, nx, ny, nd, p::PT, solver::DS, jacobian::JB, linfunc!::LF
34+
solver_f!::F, solver_h!::H, Ts, nu, nx, ny, nd,
35+
p::PT, solver::DS, jacobian::JB, linfunc!::LF
3536
) where {
3637
NT<:Real,
3738
F<:Function,
@@ -53,10 +54,10 @@ struct NonLinModel{
5354
xname = ["\$x_{$i}\$" for i in 1:nx]
5455
x0 = zeros(NT, nx)
5556
t = zeros(NT, 1)
56-
buffer = SimModelBuffer{NT}(nu, nx, ny, nd, solver.ns)
57+
buffer = SimModelBuffer{NT}(nu, nx, ny, nd, solver.ni)
5758
return new{NT, F, H, PT, DS, JB, LF}(
5859
x0,
59-
f!, h!,
60+
solver_f!, solver_h!,
6061
p,
6162
solver,
6263
Ts, t,
@@ -172,9 +173,13 @@ function NonLinModel{NT}(
172173
) where {NT<:Real}
173174
isnothing(solver) && (solver=EmptySolver())
174175
f!, h! = get_mutating_functions(NT, f, h)
175-
f!, h! = get_solver_functions(NT, solver, f!, h!, Ts, nu, nx, ny, nd)
176-
linfunc! = get_linearization_func(NT, f!, h!, nu, nx, ny, nd, p, jacobian)
177-
return NonLinModel{NT}(f!, h!, Ts, nu, nx, ny, nd, p, solver, jacobian, linfunc!)
176+
solver_f!, solver_h! = get_solver_functions(NT, solver, f!, h!, Ts, nu, nx, ny, nd)
177+
linfunc! = get_linearization_func(
178+
NT, solver_f!, solver_h!, nu, nx, ny, nd, p, solver, jacobian
179+
)
180+
return NonLinModel{NT}(
181+
solver_f!, solver_h!, Ts, nu, nx, ny, nd, p, solver, jacobian, linfunc!
182+
)
178183
end
179184

180185
function NonLinModel(
@@ -262,11 +267,11 @@ Call [`linearize(model; x, u, d)`](@ref) and return the resulting linear model.
262267
"""
263268
LinModel(model::NonLinModel; kwargs...) = linearize(model; kwargs...)
264269

265-
"Call `model.f!(xnext0, x0, u0, d0, p)` for [`NonLinModel`](@ref)."
266-
f!(xnext0, model::NonLinModel, x0, u0, d0, p) = model.f!(xnext0, model.buffer.K, x0, u0, d0, p)
270+
"Call `model.solver_f!(x0next, x0i, x0, u0, d0, p)` for [`NonLinModel`](@ref)."
271+
f!(x0next, x0i, model::NonLinModel, x0, u0, d0, p) = model.solver_f!(x0next, x0i, x0, u0, d0, p)
267272

268-
"Call `model.h!(y0, x0, d0, p)` for [`NonLinModel`](@ref)."
269-
h!(y0, model::NonLinModel, x0, d0, p) = model.h!(y0, x0, d0, p)
273+
"Call `model.solver_h!(y0, x0, d0, p)` for [`NonLinModel`](@ref)."
274+
h!(y0, model::NonLinModel, x0, d0, p) = model.solver_h!(y0, x0, d0, p)
270275

271276
detailstr(model::NonLinModel) = ", $(typeof(model.solver).name.name)($(model.solver.order)) solver"
272277
detailstr(::NonLinModel{<:Real, <:Function, <:Function, <:Any, <:EmptySolver}) = ", empty solver"

0 commit comments

Comments
 (0)