Skip to content

Commit 4a49486

Browse files
fix: linearization with MTKParameters
1 parent 1c1dd75 commit 4a49486

File tree

7 files changed

+97
-57
lines changed

7 files changed

+97
-57
lines changed

src/ModelingToolkit.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,12 +129,12 @@ include("constants.jl")
129129
include("utils.jl")
130130
include("domains.jl")
131131

132+
include("systems/index_cache.jl")
133+
include("systems/parameter_buffer.jl")
132134
include("systems/abstractsystem.jl")
133135
include("systems/model_parsing.jl")
134136
include("systems/connectors.jl")
135137
include("systems/callbacks.jl")
136-
include("systems/index_cache.jl")
137-
include("systems/parameter_buffer.jl")
138138

139139
include("systems/diffeqs/odesystem.jl")
140140
include("systems/diffeqs/sdesystem.jl")

src/systems/abstractsystem.jl

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1504,7 +1504,7 @@ function linearization_function(sys::AbstractSystem, inputs,
15041504
x0 = merge(defaults(sys), Dict(missing_variable_defaults(sys)), op)
15051505
u0, _p, _ = get_u0_p(sys, x0, p; use_union = false, tofloat = true)
15061506
if has_index_cache(sys) && get_index_cache(sys) !== nothing
1507-
p = (MTKParameters(sys, p)...,)
1507+
p = MTKParameters(sys, p)
15081508
ps = reorder_parameters(sys, parameters(sys))
15091509
else
15101510
p = _p
@@ -1515,7 +1515,6 @@ function linearization_function(sys::AbstractSystem, inputs,
15151515
ps = (ps...,) #if p is Tuple, ps should be Tuple
15161516
end
15171517
end
1518-
15191518
lin_fun = let diff_idxs = diff_idxs,
15201519
alge_idxs = alge_idxs,
15211520
input_idxs = input_idxs,
@@ -1550,7 +1549,9 @@ function linearization_function(sys::AbstractSystem, inputs,
15501549
h_xz = fg_u = zeros(0, length(inputs))
15511550
end
15521551
hp = let u = u, t = t
1553-
p -> h(u, p, t)
1552+
_hp(p) = h(u, p, t)
1553+
_hp(p::MTKParameters) = h(u, p..., t)
1554+
_hp
15541555
end
15551556
h_u = jacobian_wrt_vars(hp, p, input_idxs, chunk)
15561557
(f_x = fg_xz[diff_idxs, diff_idxs],
@@ -1592,13 +1593,14 @@ function linearize_symbolic(sys::AbstractSystem, inputs,
15921593
kwargs...)
15931594
sts = unknowns(sys)
15941595
t = get_iv(sys)
1595-
p = parameters(sys)
1596+
ps = parameters(sys)
1597+
p = reorder_parameters(sys, ps)
15961598

1597-
fun = generate_function(sys, sts, p; expression = Val{false})[1]
1598-
dx = fun(sts, p, t)
1599+
fun = generate_function(sys, sts, ps; expression = Val{false})[1]
1600+
dx = fun(sts, p..., t)
15991601

16001602
h = build_explicit_observed_function(sys, outputs)
1601-
y = h(sts, p, t)
1603+
y = h(sts, p..., t)
16021604

16031605
fg_xz = Symbolics.jacobian(dx, sts)
16041606
fg_u = Symbolics.jacobian(dx, inputs)
@@ -1794,6 +1796,15 @@ function linearize(sys, lin_fun; t = 0.0, op = Dict(), allow_input_derivatives =
17941796
x0 = merge(defaults(sys), op)
17951797
u0, p2, _ = get_u0_p(sys, x0, p; use_union = false, tofloat = true)
17961798
if has_index_cache(sys) && get_index_cache(sys) !== nothing
1799+
if p isa SciMLBase.NullParameters
1800+
p = op
1801+
elseif p isa Dict
1802+
p = merge(p, op)
1803+
elseif p isa Vector && eltype(p) <: Pair
1804+
p = merge(Dict(p), op)
1805+
elseif p isa Vector
1806+
p = merge(Dict(parameters(sys) .=> p), op)
1807+
end
17971808
p2 = MTKParameters(sys, p)
17981809
end
17991810
linres = lin_fun(u0, p2, t)

src/systems/parameter_buffer.jl

Lines changed: 73 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
struct MTKParameters{T, D, C, E, F}
1+
struct MTKParameters{T, D, C, E, F, G}
22
tunable::T
33
discrete::D
44
constant::C
55
dependent::E
6-
dependent_update::F
6+
dependent_update_iip::F
7+
dependent_update_oop::G
78
end
89

910
function MTKParameters(sys::AbstractSystem, p; tofloat = false, use_union = false)
@@ -19,12 +20,12 @@ function MTKParameters(sys::AbstractSystem, p; tofloat = false, use_union = fals
1920
length(p) == length(ps) || error("Invalid parameters")
2021
p = ps .=> p
2122
end
22-
defs = Dict(default_toterm(unwrap(k)) => v for (k, v) in defaults(sys) if unwrap(k) in all_ps)
23+
defs = Dict(default_toterm(unwrap(k)) => v for (k, v) in defaults(sys) if unwrap(k) in all_ps || default_toterm(unwrap(k)) in all_ps)
2324
if p isa SciMLBase.NullParameters
2425
p = defs
2526
else
26-
extra_params = Dict(unwrap(k) => v for (k, v) in p if !in(unwrap(k), all_ps))
27-
p = merge(defs, Dict(default_toterm(unwrap(k)) => v for (k, v) in p if unwrap(k) in all_ps))
27+
extra_params = Dict(unwrap(k) => v for (k, v) in p if !in(unwrap(k), all_ps) && !in(default_toterm(unwrap(k)), all_ps))
28+
p = merge(defs, Dict(default_toterm(unwrap(k)) => v for (k, v) in p if unwrap(k) in all_ps || default_toterm(unwrap(k)) in all_ps))
2829
p = Dict(k => fixpoint_sub(v, extra_params) for (k, v) in p if !haskey(extra_params, unwrap(k)))
2930
end
3031

@@ -74,39 +75,41 @@ function MTKParameters(sys::AbstractSystem, p; tofloat = false, use_union = fals
7475
dep_exprs[idx] = wrap(fixpoint_sub(val, dependencies))
7576
end
7677
p = reorder_parameters(ic, parameters(sys))[begin:end-length(dep_buffer.x)]
77-
update_function = if isempty(dep_exprs.x)
78-
(_...) -> ()
78+
update_function_iip, update_function_oop = if isempty(dep_exprs.x)
79+
nothing, nothing
7980
else
80-
RuntimeGeneratedFunctions.@RuntimeGeneratedFunction(build_function(dep_exprs, p...)[2])
81+
oop, iip = build_function(dep_exprs, p...)
82+
RuntimeGeneratedFunctions.@RuntimeGeneratedFunction(iip), RuntimeGeneratedFunctions.@RuntimeGeneratedFunction(oop)
8183
end
8284
# everything is an ArrayPartition so it's easy to figure out how many
8385
# distinct vectors we have for each portion as `ArrayPartition.x`
8486
if isempty(tunable_buffer.x)
85-
tunable_buffer = ArrayPartition(Float64[])
87+
tunable_buffer = Float64[]
8688
end
8789
if isempty(disc_buffer.x)
88-
disc_buffer = ArrayPartition(Float64[])
90+
disc_buffer = Float64[]
8991
end
9092
if isempty(const_buffer.x)
91-
const_buffer = ArrayPartition(Float64[])
93+
const_buffer = Float64[]
9294
end
9395
if isempty(dep_buffer.x)
94-
dep_buffer = ArrayPartition(Float64[])
96+
dep_buffer = Float64[]
9597
end
9698
if use_union
97-
tunable_buffer = ArrayPartition(restrict_array_to_union(tunable_buffer))
98-
disc_buffer = ArrayPartition(restrict_array_to_union(disc_buffer))
99-
const_buffer = ArrayPartition(restrict_array_to_union(const_buffer))
100-
dep_buffer = ArrayPartition(restrict_array_to_union(dep_buffer))
99+
tunable_buffer = restrict_array_to_union(tunable_buffer)
100+
disc_buffer = restrict_array_to_union(disc_buffer)
101+
const_buffer = restrict_array_to_union(const_buffer)
102+
dep_buffer = restrict_array_to_union(dep_buffer)
101103
elseif tofloat
102-
tunable_buffer = ArrayPartition(Float64.(tunable_buffer))
103-
disc_buffer = ArrayPartition(Float64.(disc_buffer))
104-
const_buffer = ArrayPartition(Float64.(const_buffer))
105-
dep_buffer = ArrayPartition(Float64.(dep_buffer))
104+
tunable_buffer = Float64.(tunable_buffer)
105+
disc_buffer = Float64.(disc_buffer)
106+
const_buffer = Float64.(const_buffer)
107+
dep_buffer = Float64.(dep_buffer)
106108
end
107109
return MTKParameters{typeof(tunable_buffer), typeof(disc_buffer), typeof(const_buffer),
108-
typeof(dep_buffer), typeof(update_function)}(tunable_buffer,
109-
disc_buffer, const_buffer, dep_buffer, update_function)
110+
typeof(dep_buffer), typeof(update_function_iip), typeof(update_function_oop)}(
111+
tunable_buffer, disc_buffer, const_buffer, dep_buffer, update_function_iip,
112+
update_function_oop)
110113
end
111114

112115
SciMLStructures.isscimlstructure(::MTKParameters) = true
@@ -121,19 +124,27 @@ for (Portion, field) in [
121124
@eval function SciMLStructures.canonicalize(::$Portion, p::MTKParameters)
122125
function repack(values)
123126
p.$field .= values
124-
p.dependent_update(p.dependent, p.tunable.x..., p.discrete.x..., p.constant.x...)
127+
if p.dependent_update_iip !== nothing
128+
p.dependent_update_iip(p.dependent, p...)
129+
end
130+
p
125131
end
126132
return p.$field, repack, true
127133
end
128134

129135
@eval function SciMLStructures.replace(::$Portion, p::MTKParameters, newvals)
130136
@set! p.$field = newvals
131-
p.dependent_update(p.dependent, p.tunable.x..., p.discrete.x..., p.constant.x...)
137+
if p.dependent_update_oop !== nothing
138+
@set! p.dependent = ArrayPartition(p.dependent_update_oop(p...))
139+
end
140+
p
132141
end
133142

134143
@eval function SciMLStructures.replace!(::$Portion, p::MTKParameters, newvals)
135144
p.$field .= newvals
136-
p.dependent_update(p.dependent, p.tunable.x..., p.discrete.x..., p.constant.x...)
145+
if p.dependent_update_iip !== nothing
146+
p.dependent_update_iip(p.dependent, p...)
147+
end
137148
nothing
138149
end
139150
end
@@ -166,26 +177,32 @@ function SymbolicIndexingInterface.set_parameter!(p::MTKParameters, val, idx::Pa
166177
else
167178
error("Unhandled portion $portion")
168179
end
169-
p.dependent_update(p.dependent, p.tunable.x..., p.discrete.x..., p.constant.x...)
180+
if p.dependent_update_iip !== nothing
181+
p.dependent_update_iip(p.dependent, p...)
182+
end
170183
end
171184

185+
_subarrays(v::AbstractVector) = isempty(v) ? () : (v,)
186+
_subarrays(v::ArrayPartition) = v.x
187+
_num_subarrays(v::AbstractVector) = 1
188+
_num_subarrays(v::ArrayPartition) = length(v.x)
172189
# for compiling callbacks
173190
# getindex indexes the vectors, setindex! linearly indexes values
174191
# it's inconsistent, but we need it to be this way
175192
function Base.getindex(buf::MTKParameters, i)
176193
if !isempty(buf.tunable)
177-
i <= length(buf.tunable.x) && return buf.tunable.x[i]
178-
i -= length(buf.tunable.x)
194+
i <= _num_subarrays(buf.tunable) && return _subarrays(buf.tunable)[i]
195+
i -= _num_subarrays(buf.tunable)
179196
end
180197
if !isempty(buf.discrete)
181-
i <= length(buf.discrete.x) && return buf.discrete.x[i]
182-
i -= length(buf.discrete.x)
198+
i <= _num_subarrays(buf.discrete) && return _subarrays(buf.discrete)[i]
199+
i -= _num_subarrays(buf.discrete)
183200
end
184201
if !isempty(buf.constant)
185-
i <= length(buf.constant.x) && return buf.constant.x[i]
186-
i -= length(buf.constant.x)
202+
i <= _num_subarrays(buf.constant) && return _subarrays(buf.constant)[i]
203+
i -= _num_subarrays(buf.constant)
187204
end
188-
isempty(buf.dependent) || return buf.dependent.x[i]
205+
isempty(buf.dependent) || return _subarrays(buf.dependent)[i]
189206
throw(BoundsError(buf, i))
190207
end
191208
function Base.setindex!(buf::MTKParameters, val, i)
@@ -196,15 +213,17 @@ function Base.setindex!(buf::MTKParameters, val, i)
196213
else
197214
buf.constant[i - length(buf.tunable) - length(buf.discrete)] = val
198215
end
199-
buf.dependent_update(buf.dependent, buf.tunable.x..., buf.discrete.x..., buf.constant.x...)
216+
if buf.dependent_update_iip !== nothing
217+
buf.dependent_update_iip(buf.dependent, buf...)
218+
end
200219
end
201220

202221
function Base.iterate(buf::MTKParameters, state = 1)
203222
total_len = 0
204-
isempty(buf.tunable) || (total_len += length(buf.tunable.x))
205-
isempty(buf.discrete) || (total_len += length(buf.discrete.x))
206-
isempty(buf.constant) || (total_len += length(buf.constant.x))
207-
isempty(buf.dependent) || (total_len += length(buf.dependent.x))
223+
isempty(buf.tunable) || (total_len += _num_subarrays(buf.tunable))
224+
isempty(buf.discrete) || (total_len += _num_subarrays(buf.discrete))
225+
isempty(buf.constant) || (total_len += _num_subarrays(buf.constant))
226+
isempty(buf.dependent) || (total_len += _num_subarrays(buf.dependent))
208227
if state <= total_len
209228
return (buf[state], state + 1)
210229
else
@@ -229,14 +248,24 @@ function jacobian_wrt_vars(pf::F, p::MTKParameters, input_idxs, chunk::C) where
229248
p_big = p_big
230249

231250
function (p_small_inner)
232-
tunable, repack, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p_big)
233-
tunable[input_idxs] .= p_small_inner
234-
p_big = repack(tunable)
235-
pf(p_big)
251+
for (i, val) in zip(input_idxs, p_small_inner)
252+
set_parameter!(p_big, val, i)
253+
end
254+
# tunable, repack, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p_big)
255+
# tunable[input_idxs] .= p_small_inner
256+
# p_big = repack(tunable)
257+
return if pf isa SciMLBase.ParamJacobianWrapper
258+
buffer = similar(p_big.tunable, size(pf.u))
259+
pf(buffer, p_big)
260+
buffer
261+
else
262+
pf(p_big)
263+
end
236264
end
237265
end
238-
tunable, _, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)
239-
p_small = tunable[input_idxs]
266+
# tunable, _, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)
267+
# p_small = tunable[input_idxs]
268+
p_small = parameter_values.((p,), input_idxs)
240269
cfg = ForwardDiff.JacobianConfig(p_closure, p_small, chunk, tag)
241270
ForwardDiff.jacobian(p_closure, p_small, cfg, Val(false))
242271
end

src/systems/systems.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ function structural_simplify(sys::AbstractSystem, io = nothing; simplify = false
2828
@set! newsys.parent = complete(sys)
2929
newsys = complete(newsys)
3030
if newsys′ isa Tuple
31-
return newsys, newsys′[2]
31+
idxs = [parameter_index(newsys, i) for i in io[1]]
32+
return newsys, idxs
3233
else
3334
return newsys
3435
end

test/inversemodel.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,8 @@ nsys = get_named_sensitivity(model, :y; op) # Test that we get the same result w
157157
# Test the same thing for comp sensitivities
158158

159159
Sf, simplified_sys = Blocks.get_comp_sensitivity_function(model, :y) # This should work without providing an operating opint containing a dummy derivative
160-
x, p = ModelingToolkit.get_u0_p(simplified_sys, op)
160+
x, _ = ModelingToolkit.get_u0_p(simplified_sys, op)
161+
p = ModelingToolkit.MTKParameters(simplified_sys, op)
161162
matrices1 = Sf(x, p, 0)
162163
matrices2, _ = Blocks.get_comp_sensitivity(model, :y; op) # Test that we get the same result when calling the higher-level API
163164
@test matrices1.f_x matrices2.A[1:7, 1:7]

test/linearize.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,6 @@ function saturation(; y_max, y_min = y_max > 0 ? -y_max : -Inf, name)
170170
]
171171
ODESystem(eqs, t, name = name)
172172
end
173-
174173
@named sat = saturation(; y_max = 1)
175174
# inside the linear region, the function is identity
176175
@unpack u, y = sat

test/symbolic_events.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -498,7 +498,6 @@ let rng = rng
498498
function testsol(jsys, u0, p, tspan; tstops = Float64[], paramtotest = nothing,
499499
N = 40000, kwargs...)
500500
jsys = complete(jsys)
501-
@show ModelingToolkit.get_index_cache(jsys)
502501
dprob = DiscreteProblem(jsys, u0, tspan, p)
503502
jprob = JumpProblem(jsys, dprob, Direct(); kwargs...)
504503
sol = solve(jprob, SSAStepper(); tstops = tstops)

0 commit comments

Comments
 (0)