Skip to content

Commit e6c9eec

Browse files
fix: handle Initial parameters in code generation
1 parent 19bedde commit e6c9eec

File tree

12 files changed

+48
-39
lines changed

12 files changed

+48
-39
lines changed

src/inputoutput.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu
211211
sys, _ = io_preprocessing(sys, inputs, []; simplify, kwargs...)
212212

213213
dvs = unknowns(sys)
214-
ps = parameters(sys)
214+
ps = parameters(sys; initial_parameters = true)
215215
ps = setdiff(ps, inputs)
216216
if disturbance_inputs !== nothing
217217
# remove from inputs since we do not want them as actual inputs to the dynamics
@@ -234,16 +234,14 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu
234234
[eq.rhs for eq in eqs]
235235

236236
# TODO: add an optional check on the ordering of observed equations
237-
u = map(x -> time_varying_as_func(value(x), sys), dvs)
238-
p = map(x -> time_varying_as_func(value(x), sys), ps)
239-
p = reorder_parameters(sys, p)
237+
p = reorder_parameters(sys, ps)
240238
t = get_iv(sys)
241239

242240
# pre = has_difference ? (ex -> ex) : get_postprocess_fbody(sys)
243241
if disturbance_argument
244-
args = (u, inputs, p..., t, disturbance_inputs)
242+
args = (dvs, inputs, p..., t, disturbance_inputs)
245243
else
246-
args = (u, inputs, p..., t)
244+
args = (dvs, inputs, p..., t)
247245
end
248246
if implicit_dae
249247
ddvs = map(Differential(get_iv(sys)), dvs)

src/systems/abstractsystem.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ time-independent systems. If `split=true` (the default) was passed to [`complete
161161
object.
162162
"""
163163
function generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys),
164-
ps = parameters(sys);
164+
ps = parameters(sys; initial_parameters = true);
165165
expression = Val{true}, eval_expression = false, eval_module = @__MODULE__,
166166
cachesyms::Tuple = (), kwargs...)
167167
if !iscomplete(sys)
@@ -533,7 +533,7 @@ function SymbolicIndexingInterface.get_all_timeseries_indexes(sys::AbstractSyste
533533
end
534534

535535
function SymbolicIndexingInterface.parameter_symbols(sys::AbstractSystem)
536-
return parameters(sys)
536+
return parameters(sys; initial_parameters = true)
537537
end
538538

539539
function SymbolicIndexingInterface.is_independent_variable(sys::AbstractSystem, sym)
@@ -2391,7 +2391,7 @@ function linearize_symbolic(sys::AbstractSystem, inputs,
23912391
kwargs...)
23922392
sts = unknowns(sys)
23932393
t = get_iv(sys)
2394-
ps = parameters(sys)
2394+
ps = parameters(sys; initial_parameters = true)
23952395
p = reorder_parameters(sys, ps)
23962396

23972397
fun_expr = generate_function(sys, sts, ps; expression = Val{true})[1]

src/systems/callbacks.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -682,7 +682,7 @@ function compile_affect(eqs::Vector{Equation}, cb, sys, dvs, ps; outputidxs = no
682682
end
683683

684684
function generate_rootfinding_callback(sys::AbstractTimeDependentSystem,
685-
dvs = unknowns(sys), ps = parameters(sys); kwargs...)
685+
dvs = unknowns(sys), ps = parameters(sys; initial_parameters = true); kwargs...)
686686
cbs = continuous_events(sys)
687687
isempty(cbs) && return nothing
688688
generate_rootfinding_callback(cbs, sys, dvs, ps; kwargs...)
@@ -693,7 +693,7 @@ generate_rootfinding_callback and thus we can produce a ContinuousCallback inste
693693
"""
694694
function generate_single_rootfinding_callback(
695695
eq, cb, sys::AbstractTimeDependentSystem, dvs = unknowns(sys),
696-
ps = parameters(sys); kwargs...)
696+
ps = parameters(sys; initial_parameters = true); kwargs...)
697697
if !isequal(eq.lhs, 0)
698698
eq = 0 ~ eq.lhs - eq.rhs
699699
end
@@ -736,7 +736,7 @@ end
736736

737737
function generate_vector_rootfinding_callback(
738738
cbs, sys::AbstractTimeDependentSystem, dvs = unknowns(sys),
739-
ps = parameters(sys); rootfind = SciMLBase.RightRootFind,
739+
ps = parameters(sys; initial_parameters = true); rootfind = SciMLBase.RightRootFind,
740740
reinitialization = SciMLBase.CheckInit(), kwargs...)
741741
eqs = map(cb -> flatten_equations(cb.eqs), cbs)
742742
num_eqs = length.(eqs)
@@ -861,7 +861,7 @@ function compile_affect_fn(cb, sys::AbstractTimeDependentSystem, dvs, ps, kwargs
861861
end
862862

863863
function generate_rootfinding_callback(cbs, sys::AbstractTimeDependentSystem,
864-
dvs = unknowns(sys), ps = parameters(sys); kwargs...)
864+
dvs = unknowns(sys), ps = parameters(sys; initial_parameters = true); kwargs...)
865865
eqs = map(cb -> flatten_equations(cb.eqs), cbs)
866866
num_eqs = length.(eqs)
867867
total_eqs = sum(num_eqs)
@@ -949,7 +949,9 @@ function invalid_variables(sys, expr)
949949
end
950950
function unassignable_variables(sys, expr)
951951
assignable_syms = reduce(
952-
vcat, Symbolics.scalarize.(vcat(unknowns(sys), parameters(sys))); init = [])
952+
vcat, Symbolics.scalarize.(vcat(
953+
unknowns(sys), parameters(sys; initial_parameters = true)));
954+
init = [])
953955
written = reduce(vcat, Symbolics.scalarize.(vars(expr)); init = [])
954956
return filter(
955957
x -> !any(isequal(x), assignable_syms), written)
@@ -1075,7 +1077,7 @@ function generate_discrete_callback(cb, sys, dvs, ps; postprocess_affect_expr! =
10751077
end
10761078

10771079
function generate_discrete_callbacks(sys::AbstractSystem, dvs = unknowns(sys),
1078-
ps = parameters(sys); kwargs...)
1080+
ps = parameters(sys; initial_parameters = true); kwargs...)
10791081
has_discrete_events(sys) || return nothing
10801082
symcbs = discrete_events(sys)
10811083
isempty(symcbs) && return nothing

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,8 @@ function calculate_control_jacobian(sys::AbstractODESystem;
105105
end
106106

107107
function generate_tgrad(
108-
sys::AbstractODESystem, dvs = unknowns(sys), ps = parameters(sys);
108+
sys::AbstractODESystem, dvs = unknowns(sys), ps = parameters(
109+
sys; initial_parameters = true);
109110
simplify = false, kwargs...)
110111
tgrad = calculate_tgrad(sys, simplify = simplify)
111112
p = reorder_parameters(sys, ps)
@@ -117,7 +118,7 @@ function generate_tgrad(
117118
end
118119

119120
function generate_jacobian(sys::AbstractODESystem, dvs = unknowns(sys),
120-
ps = parameters(sys);
121+
ps = parameters(sys; initial_parameters = true);
121122
simplify = false, sparse = false, kwargs...)
122123
jac = calculate_jacobian(sys; simplify = simplify, sparse = sparse)
123124
p = reorder_parameters(sys, ps)
@@ -129,15 +130,15 @@ function generate_jacobian(sys::AbstractODESystem, dvs = unknowns(sys),
129130
end
130131

131132
function generate_control_jacobian(sys::AbstractODESystem, dvs = unknowns(sys),
132-
ps = parameters(sys);
133+
ps = parameters(sys; initial_parameters = true);
133134
simplify = false, sparse = false, kwargs...)
134135
jac = calculate_control_jacobian(sys; simplify = simplify, sparse = sparse)
135136
p = reorder_parameters(sys, ps)
136137
return build_function_wrapper(sys, jac, dvs, p..., get_iv(sys); kwargs...)
137138
end
138139

139140
function generate_dae_jacobian(sys::AbstractODESystem, dvs = unknowns(sys),
140-
ps = parameters(sys); simplify = false, sparse = false,
141+
ps = parameters(sys; initial_parameters = true); simplify = false, sparse = false,
141142
kwargs...)
142143
jac_u = calculate_jacobian(sys; simplify = simplify, sparse = sparse)
143144
derivatives = Differential(get_iv(sys)).(unknowns(sys))
@@ -153,7 +154,7 @@ function generate_dae_jacobian(sys::AbstractODESystem, dvs = unknowns(sys),
153154
end
154155

155156
function generate_function(sys::AbstractODESystem, dvs = unknowns(sys),
156-
ps = parameters(sys);
157+
ps = parameters(sys; initial_parameters = true);
157158
implicit_dae = false,
158159
ddvs = implicit_dae ? map(Differential(get_iv(sys)), dvs) :
159160
nothing,
@@ -691,7 +692,7 @@ function SymbolicTstops(
691692
term(:, t0, unwrap(val), t1; type = AbstractArray{Real})
692693
end
693694
end
694-
rps = reorder_parameters(sys, parameters(sys))
695+
rps = reorder_parameters(sys)
695696
tstops, _ = build_function_wrapper(sys, tstops,
696697
rps...,
697698
t0,
@@ -817,7 +818,7 @@ function DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem, du0map, u0map, tspan
817818
end
818819

819820
function generate_history(sys::AbstractODESystem, u0; expression = Val{false}, kwargs...)
820-
p = reorder_parameters(sys, parameters(sys))
821+
p = reorder_parameters(sys)
821822
build_function_wrapper(
822823
sys, u0, p..., get_iv(sys); expression, p_start = 1, p_end = length(p),
823824
similarto = typeof(u0), wrap_delays = false, kwargs...)

src/systems/diffeqs/odesystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,7 @@ function build_explicit_observed_function(sys, ts;
424424
eval_module = @__MODULE__,
425425
output_type = Array,
426426
checkbounds = true,
427-
ps = parameters(sys),
427+
ps = parameters(sys; initial_parameters = true),
428428
return_inplace = false,
429429
param_only = false,
430430
op = Operator,

src/systems/diffeqs/sdesystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ function __get_num_diag_noise(mat)
409409
end
410410

411411
function generate_diffusion_function(sys::SDESystem, dvs = unknowns(sys),
412-
ps = parameters(sys); isdde = false, kwargs...)
412+
ps = parameters(sys; initial_parameters = true); isdde = false, kwargs...)
413413
eqs = get_noiseeqs(sys)
414414
p = reorder_parameters(sys, ps)
415415
return build_function_wrapper(sys, eqs, dvs, p..., get_iv(sys); kwargs...)

src/systems/jumps/jumpsystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ function generate_rate_function(js::JumpSystem, rate)
282282
csubs = Dict(c => getdefault(c) for c in consts)
283283
rate = substitute(rate, csubs)
284284
end
285-
p = reorder_parameters(js, parameters(js))
285+
p = reorder_parameters(js)
286286
rf = build_function_wrapper(js, rate, unknowns(js), p...,
287287
get_iv(js),
288288
expression = Val{true})
@@ -634,7 +634,7 @@ end
634634
function JumpSysMajParamMapper(js::JumpSystem, p; jseqs = nothing, rateconsttype = Float64)
635635
eqs = (jseqs === nothing) ? equations(js) : jseqs
636636
paramexprs = [maj.scaled_rates for maj in eqs.x[1]]
637-
psyms = reduce(vcat, reorder_parameters(js, parameters(js)); init = [])
637+
psyms = reduce(vcat, reorder_parameters(js); init = [])
638638
paramdict = Dict(value(k) => value(v) for (k, v) in zip(psyms, vcat(p...)))
639639
JumpSysMajParamMapper{typeof(paramexprs), typeof(psyms), rateconsttype}(paramexprs,
640640
psyms,

src/systems/nonlinear/initializesystem.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,9 @@ end
339339

340340
function ReconstructInitializeprob(
341341
srcsys::AbstractSystem, dstsys::AbstractSystem; remap = Dict())
342-
syms = reduce(vcat, reorder_parameters(dstsys, parameters(dstsys)); init = [])
342+
syms = reduce(
343+
vcat, reorder_parameters(dstsys, parameters(dstsys; initial_parameters = true));
344+
init = [])
343345
getter = getu(srcsys, map(x -> get(remap, x, x), syms))
344346
setter = setp_oop(dstsys, syms)
345347
return ReconstructInitializeprob(getter, setter)

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,8 @@ function calculate_jacobian(sys::NonlinearSystem; sparse = false, simplify = fal
248248
end
249249

250250
function generate_jacobian(
251-
sys::NonlinearSystem, vs = unknowns(sys), ps = parameters(sys);
251+
sys::NonlinearSystem, vs = unknowns(sys), ps = parameters(
252+
sys; initial_parameters = true);
252253
sparse = false, simplify = false, kwargs...)
253254
jac = calculate_jacobian(sys, sparse = sparse, simplify = simplify)
254255
p = reorder_parameters(sys, ps)
@@ -268,15 +269,17 @@ function calculate_hessian(sys::NonlinearSystem; sparse = false, simplify = fals
268269
end
269270

270271
function generate_hessian(
271-
sys::NonlinearSystem, vs = unknowns(sys), ps = parameters(sys);
272+
sys::NonlinearSystem, vs = unknowns(sys), ps = parameters(
273+
sys; initial_parameters = true);
272274
sparse = false, simplify = false, kwargs...)
273275
hess = calculate_hessian(sys, sparse = sparse, simplify = simplify)
274276
p = reorder_parameters(sys, ps)
275277
return build_function_wrapper(sys, hess, vs, p...; kwargs...)
276278
end
277279

278280
function generate_function(
279-
sys::NonlinearSystem, dvs = unknowns(sys), ps = parameters(sys);
281+
sys::NonlinearSystem, dvs = unknowns(sys), ps = parameters(
282+
sys; initial_parameters = true);
280283
scalar = false, kwargs...)
281284
rhss = [deq.rhs for deq in equations(sys)]
282285
dvs′ = value.(dvs)
@@ -569,7 +572,7 @@ end
569572
function CacheWriter(sys::AbstractSystem, buffer_types::Vector{TypeT},
570573
exprs::Dict{TypeT, Vector{Any}}, solsyms, obseqs::Vector{Equation};
571574
eval_expression = false, eval_module = @__MODULE__)
572-
ps = parameters(sys)
575+
ps = parameters(sys; initial_parameters = true)
573576
rps = reorder_parameters(sys, ps)
574577
obs_assigns = [eq.lhs eq.rhs for eq in obseqs]
575578
body = map(eachindex(buffer_types), buffer_types) do i, T
@@ -589,7 +592,7 @@ struct SCCNonlinearFunction{iip} end
589592
function SCCNonlinearFunction{iip}(
590593
sys::NonlinearSystem, _eqs, _dvs, _obs, cachesyms; eval_expression = false,
591594
eval_module = @__MODULE__, kwargs...) where {iip}
592-
ps = parameters(sys)
595+
ps = parameters(sys; initial_parameters = true)
593596
rps = reorder_parameters(sys, ps)
594597

595598
obs_assignments = [eq.lhs eq.rhs for eq in _obs]

src/systems/optimization/constraints_system.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,8 @@ function calculate_jacobian(sys::ConstraintsSystem; sparse = false, simplify = f
171171
end
172172

173173
function generate_jacobian(
174-
sys::ConstraintsSystem, vs = unknowns(sys), ps = parameters(sys);
174+
sys::ConstraintsSystem, vs = unknowns(sys), ps = parameters(
175+
sys; initial_parameters = true);
175176
sparse = false, simplify = false, kwargs...)
176177
jac = calculate_jacobian(sys, sparse = sparse, simplify = simplify)
177178
p = reorder_parameters(sys, ps)
@@ -190,15 +191,16 @@ function calculate_hessian(sys::ConstraintsSystem; sparse = false, simplify = fa
190191
end
191192

192193
function generate_hessian(
193-
sys::ConstraintsSystem, vs = unknowns(sys), ps = parameters(sys);
194+
sys::ConstraintsSystem, vs = unknowns(sys), ps = parameters(
195+
sys; initial_parameters = true);
194196
sparse = false, simplify = false, kwargs...)
195197
hess = calculate_hessian(sys, sparse = sparse, simplify = simplify)
196198
p = reorder_parameters(sys, ps)
197199
return build_function_wrapper(sys, hess, vs, p...; kwargs...)
198200
end
199201

200202
function generate_function(sys::ConstraintsSystem, dvs = unknowns(sys),
201-
ps = parameters(sys);
203+
ps = parameters(sys; initial_parameters = true);
202204
kwargs...)
203205
lhss = generate_canonical_form_lhss(sys)
204206
p = reorder_parameters(sys, value.(ps))

0 commit comments

Comments
 (0)