Skip to content

Commit 247160d

Browse files
feat: use DI to calculate jacobians in LinearizationFunction
1 parent f2d00b0 commit 247160d

File tree

3 files changed

+107
-20
lines changed

3 files changed

+107
-20
lines changed

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ authors = ["Yingbo Ma <[email protected]>", "Chris Rackauckas <accounts@chrisr
44
version = "9.68.1"
55

66
[deps]
7+
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
78
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
89
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
910
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
@@ -16,6 +17,7 @@ DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
1617
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
1718
DiffEqNoiseProcess = "77a26b50-5914-5dd7-bc55-306e6241c503"
1819
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
20+
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
1921
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
2022
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
2123
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
@@ -77,6 +79,7 @@ MTKInfiniteOptExt = "InfiniteOpt"
7779
MTKLabelledArraysExt = "LabelledArrays"
7880

7981
[compat]
82+
ADTypes = "1.14.0"
8083
AbstractTrees = "0.3, 0.4"
8184
ArrayInterface = "6, 7"
8285
BifurcationKit = "0.4"
@@ -96,6 +99,7 @@ DiffEqBase = "6.165.1"
9699
DiffEqCallbacks = "2.16, 3, 4"
97100
DiffEqNoiseProcess = "5"
98101
DiffRules = "0.1, 1.0"
102+
DifferentiationInterface = "0.6.47"
99103
Distributed = "1"
100104
Distributions = "0.23, 0.24, 0.25"
101105
DocStringExtensions = "0.7, 0.8, 0.9"
@@ -156,8 +160,8 @@ julia = "1.9"
156160
[extras]
157161
AmplNLWriter = "7c4d4715-977e-5154-bfe0-e096adeac482"
158162
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
159-
BoundaryValueDiffEqMIRK = "1a22d4ce-7765-49ea-b6f2-13c8438986a6"
160163
BoundaryValueDiffEqAscher = "7227322d-7511-4e07-9247-ad6ff830280e"
164+
BoundaryValueDiffEqMIRK = "1a22d4ce-7765-49ea-b6f2-13c8438986a6"
161165
ControlSystemsBase = "aaaaaaaa-a6ca-5380-bf3e-84a91bcd477e"
162166
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
163167
DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6"

src/ModelingToolkit.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,9 @@ RuntimeGeneratedFunctions.init(@__MODULE__)
9494
import DynamicQuantities, Unitful
9595
const DQ = DynamicQuantities
9696

97+
import DifferentiationInterface as DI
98+
using ADTypes: AutoForwardDiff
99+
97100
export @derivatives
98101

99102
for fun in [:toexpr]

src/linearization.jl

Lines changed: 99 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ The `simplified_sys` has undergone [`structural_simplify`](@ref) and had any occ
2525
- `simplify`: Apply simplification in tearing.
2626
- `initialize`: If true, a check is performed to ensure that the operating point is consistent (satisfies algebraic equations). If the op is not consistent, initialization is performed.
2727
- `initialization_solver_alg`: A NonlinearSolve algorithm to use for solving for a feasible set of state and algebraic variables that satisfies the specified operating point.
28+
- `autodiff`: An `ADType` supported by DifferentiationInterface.jl to use for calculating the necessary jacobians.
2829
- `kwargs`: Are passed on to `find_solvables!`
2930
3031
See also [`linearize`](@ref) which provides a higher-level interface.
@@ -39,6 +40,7 @@ function linearization_function(sys::AbstractSystem, inputs,
3940
p = DiffEqBase.NullParameters(),
4041
zero_dummy_der = false,
4142
initialization_solver_alg = TrustRegion(),
43+
autodiff = AutoForwardDiff(),
4244
eval_expression = false, eval_module = @__MODULE__,
4345
warn_initialize_determined = true,
4446
guesses = Dict(),
@@ -82,13 +84,89 @@ function linearization_function(sys::AbstractSystem, inputs,
8284
initialization_kwargs = (;
8385
abstol = initialization_abstol, reltol = initialization_reltol,
8486
nlsolve_alg = initialization_solver_alg)
87+
88+
p = parameter_values(prob)
89+
t0 = current_time(prob)
90+
inputvals = [p[idx] for idx in input_idxs]
91+
92+
uf_fun = let fun = prob.f
93+
function uff(du, u, p, t)
94+
SciMLBase.UJacobianWrapper(fun, t, p)(du, u)
95+
end
96+
end
97+
uf_jac = PreparedJacobian{true}(uf_fun, similar(prob.u0), autodiff, prob.u0, DI.Constant(p), DI.Constant(t0))
98+
# observed function is a `GeneratedFunctionWrapper` with iip component
99+
h_jac = PreparedJacobian{true}(h, similar(prob.u0, size(outputs)), autodiff, prob.u0, DI.Constant(p), DI.Constant(t0))
100+
pf_fun = let fun = prob.f, setter = setp_oop(sys, input_idxs)
101+
function pff(du, input, u, p, t)
102+
p = setter(p, input)
103+
SciMLBase.ParamJacobianWrapper(fun, t, u)(du, p)
104+
end
105+
end
106+
pf_jac = PreparedJacobian{true}(pf_fun, similar(prob.u0), autodiff, inputvals, DI.Constant(prob.u0), DI.Constant(p), DI.Constant(t0))
107+
hp_fun = let fun = h, setter = setp_oop(sys, input_idxs)
108+
function hpf(du, input, u, p, t)
109+
p = setter(p, input)
110+
fun(du, u, p, t)
111+
return du
112+
end
113+
end
114+
hp_jac = PreparedJacobian{true}(hp_fun, similar(prob.u0, size(outputs)), autodiff, inputvals, DI.Constant(prob.u0), DI.Constant(p), DI.Constant(t0))
115+
85116
lin_fun = LinearizationFunction(
86117
diff_idxs, alge_idxs, input_idxs, length(unknowns(sys)),
87-
prob, h, u0 === nothing ? nothing : similar(u0),
88-
ForwardDiff.Chunk(input_idxs), initializealg, initialization_kwargs)
118+
prob, h, u0 === nothing ? nothing : similar(u0), uf_jac, h_jac, pf_jac,
119+
hp_jac, initializealg, initialization_kwargs)
89120
return lin_fun, sys
90121
end
91122

123+
"""
124+
$(TYPEDEF)
125+
126+
Callable struct which stores a function and its prepared `DI.jacobian`. Calling with the
127+
appropriate arguments for DI returns the jacobian.
128+
129+
# Fields
130+
131+
$(TYPEDFIELDS)
132+
"""
133+
struct PreparedJacobian{iip, P, F, B, A}
134+
"""
135+
The preparation object.
136+
"""
137+
prep::P
138+
"""
139+
The function whose jacobian is calculated.
140+
"""
141+
f::F
142+
"""
143+
Buffer for in-place functions.
144+
"""
145+
buf::B
146+
"""
147+
ADType to use for differentiation.
148+
"""
149+
autodiff::A
150+
end
151+
152+
function PreparedJacobian{true}(f, buf, autodiff, args...)
153+
prep = DI.prepare_jacobian(f, buf, autodiff, args...)
154+
return PreparedJacobian{true, typeof(prep), typeof(f), typeof(buf), typeof(autodiff)}(prep, f, buf, autodiff)
155+
end
156+
157+
function PreparedJacobian{false}(f, autodiff, args...)
158+
prep = DI.prepare_jacobian(f, autodiff, args...)
159+
return PreparedJacobian{true, typeof(prep), typeof(f), Nothing, typeof(autodiff)}(prep, f, nothing)
160+
end
161+
162+
function (pj::PreparedJacobian{true})(args...)
163+
DI.jacobian(pj.f, pj.buf, pj.prep, pj.autodiff, args...)
164+
end
165+
166+
function (pj::PreparedJacobian{false})(args...)
167+
DI.jacobian(pj.f, pj.prep, pj.autodiff, args...)
168+
end
169+
92170
"""
93171
$(TYPEDEF)
94172
@@ -100,7 +178,7 @@ $(TYPEDFIELDS)
100178
"""
101179
struct LinearizationFunction{
102180
DI <: AbstractVector{Int}, AI <: AbstractVector{Int}, II, P <: ODEProblem,
103-
H, C, Ch, IA <: SciMLBase.DAEInitializationAlgorithm, IK}
181+
H, C, J1, J2, J3, J4, IA <: SciMLBase.DAEInitializationAlgorithm, IK}
104182
"""
105183
The indexes of differential equations in the linearized system.
106184
"""
@@ -130,11 +208,22 @@ struct LinearizationFunction{
130208
Any required cache buffers.
131209
"""
132210
caches::C
133-
# TODO: Use DI?
134211
"""
135-
A `ForwardDiff.Chunk` for taking the jacobian with respect to the inputs.
212+
`PreparedJacobian` for calculating jacobian of `prob.f` w.r.t. `u`
213+
"""
214+
uf_jac::J1
215+
"""
216+
`PreparedJacobian` for calculating jacobian of `h` w.r.t. `u`
136217
"""
137-
chunk::Ch
218+
h_jac::J2
219+
"""
220+
`PreparedJacobian` for calculating jacobian of `prob.f` w.r.t. `p`
221+
"""
222+
pf_jac::J3
223+
"""
224+
`PreparedJacobian` for calculating jacobian of `h` w.r.t. `p`
225+
"""
226+
hp_jac::J4
138227
"""
139228
The initialization algorithm to use.
140229
"""
@@ -188,25 +277,16 @@ function (linfun::LinearizationFunction)(u, p, t)
188277
if !success
189278
error("Initialization algorithm $(linfun.initializealg) failed with `u = $u` and `p = $p`.")
190279
end
191-
uf = SciMLBase.UJacobianWrapper(fun, t, p)
192-
fg_xz = ForwardDiff.jacobian(uf, u)
193-
h_xz = ForwardDiff.jacobian(
194-
let p = p, t = t, h = linfun.h
195-
xz -> h(xz, p, t)
196-
end, u)
197-
pf = SciMLBase.ParamJacobianWrapper(fun, t, u)
198-
fg_u = jacobian_wrt_vars(pf, p, linfun.input_idxs, linfun.chunk)
280+
fg_xz = linfun.uf_jac(u, DI.Constant(p), DI.Constant(t))
281+
h_xz = linfun.h_jac(u, DI.Constant(p), DI.Constant(t))
282+
fg_u = linfun.pf_jac([p[idx] for idx in linfun.input_idxs], DI.Constant(u), DI.Constant(p), DI.Constant(t))
199283
else
200284
linfun.num_states == 0 ||
201285
error("Number of unknown variables (0) does not match the number of input unknowns ($(length(u)))")
202286
fg_xz = zeros(0, 0)
203287
h_xz = fg_u = zeros(0, length(linfun.input_idxs))
204288
end
205-
hp = let u = u, t = t, h = linfun.h
206-
_hp(p) = h(u, p, t)
207-
_hp
208-
end
209-
h_u = jacobian_wrt_vars(hp, p, linfun.input_idxs, linfun.chunk)
289+
h_u = linfun.hp_jac([p[idx] for idx in linfun.input_idxs], DI.Constant(u), DI.Constant(p), DI.Constant(t))
210290
(f_x = fg_xz[linfun.diff_idxs, linfun.diff_idxs],
211291
f_z = fg_xz[linfun.diff_idxs, linfun.alge_idxs],
212292
g_x = fg_xz[linfun.alge_idxs, linfun.diff_idxs],

0 commit comments

Comments
 (0)