Skip to content

Commit 9e03874

Browse files
authored
Merge pull request #1724 from SciML/myb/partial_jac
Use efficient partial Jacobian evaluation in `linearization_function`
2 parents 3212f1a + f61434c commit 9e03874

File tree

2 files changed

+36
-8
lines changed

2 files changed

+36
-8
lines changed

src/systems/abstractsystem.jl

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1019,28 +1019,35 @@ function linearization_function(sys::AbstractSystem, inputs,
10191019
kwargs...)
10201020
sys, diff_idxs, alge_idxs, input_idxs = io_preprocessing(sys, inputs, outputs; simplify,
10211021
kwargs...)
1022-
sts = states(sys)
1023-
fun = ODEFunction(sys)
1024-
lin_fun = let fun = fun,
1025-
h = ModelingToolkit.build_explicit_observed_function(sys, outputs)
1022+
lin_fun = let diff_idxs = diff_idxs,
1023+
alge_idxs = alge_idxs,
1024+
input_idxs = input_idxs,
1025+
sts = states(sys),
1026+
fun = ODEFunction(sys),
1027+
h = ModelingToolkit.build_explicit_observed_function(sys, outputs),
1028+
chunk = ForwardDiff.Chunk(input_idxs)
10261029

10271030
function (u, p, t)
10281031
if u !== nothing # Handle systems without states
10291032
length(sts) == length(u) ||
10301033
error("Number of state variables ($(length(sts))) does not match the number of input states ($(length(u)))")
10311034
uf = SciMLBase.UJacobianWrapper(fun, t, p)
10321035
fg_xz = ForwardDiff.jacobian(uf, u)
1033-
h_xz = ForwardDiff.jacobian(xz -> h(xz, p, t), u)
1036+
h_xz = ForwardDiff.jacobian(let p = p, t = t
1037+
xz -> h(xz, p, t)
1038+
end, u)
10341039
pf = SciMLBase.ParamJacobianWrapper(fun, t, u)
1035-
# TODO: this is very inefficient, p contains all parameters of the system
1036-
fg_u = ForwardDiff.jacobian(pf, p)[:, input_idxs]
1040+
fg_u = jacobian_wrt_vars(pf, p, input_idxs, chunk)
10371041
else
10381042
length(sts) == 0 ||
10391043
error("Number of state variables (0) does not match the number of input states ($(length(u)))")
10401044
fg_xz = zeros(0, 0)
10411045
h_xz = fg_u = zeros(0, length(inputs))
10421046
end
1043-
h_u = ForwardDiff.jacobian(p -> h(u, p, t), p)[:, input_idxs]
1047+
hp = let u = u, t = t
1048+
p -> h(u, p, t)
1049+
end
1050+
h_u = jacobian_wrt_vars(hp, p, input_idxs, chunk)
10441051
(f_x = fg_xz[diff_idxs, diff_idxs],
10451052
f_z = fg_xz[diff_idxs, alge_idxs],
10461053
g_x = fg_xz[alge_idxs, diff_idxs],

src/utils.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -710,3 +710,24 @@ function Base.iterate(it::StatefulBFS, queue = (eltype(it)[(0, it.t)]))
710710
end
711711
return (lv, t), queue
712712
end
713+
714+
function jacobian_wrt_vars(pf::F, p, input_idxs, chunk::C) where {F, C}
715+
E = eltype(p)
716+
tag = ForwardDiff.Tag(pf, E)
717+
T = typeof(tag)
718+
dualtype = ForwardDiff.Dual{T, E, ForwardDiff.chunksize(chunk)}
719+
p_big = similar(p, dualtype)
720+
copyto!(p_big, p)
721+
p_closure = let pf = pf,
722+
input_idxs = input_idxs,
723+
p_big = p_big
724+
725+
function (p_small_inner)
726+
p_big[input_idxs] .= p_small_inner
727+
pf(p_big)
728+
end
729+
end
730+
p_small = p[input_idxs]
731+
cfg = ForwardDiff.JacobianConfig(p_closure, p_small, chunk, tag)
732+
ForwardDiff.jacobian(p_closure, p_small, cfg, Val(false))
733+
end

0 commit comments

Comments
 (0)