Skip to content

Commit b150fe2

Browse files
Merge pull request #3641 from AayushSabharwal/as/no-constants
refactor: make `@constants` create `tunable = false` parameters
2 parents f543afa + 83a7cb0 commit b150fe2

27 files changed

+95
-546
lines changed

src/constants.jl

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,9 @@
1-
import SymbolicUtils: symtype, term, hasmetadata, issym
2-
struct MTKConstantCtx end
3-
4-
isconstant(x::Num) = isconstant(unwrap(x))
51
"""
62
Test whether `x` is a constant-type Sym.
73
"""
84
function isconstant(x)
95
x = unwrap(x)
10-
x isa Symbolic && getmetadata(x, MTKConstantCtx, false)
6+
x isa Symbolic && !getmetadata(x, VariableTunable, true)
117
end
128

139
"""
@@ -16,12 +12,11 @@ end
1612
Maps the parameter to a constant. The parameter must have a default.
1713
"""
1814
function toconstant(s)
19-
hasmetadata(s, Symbolics.VariableDefaultValue) ||
20-
throw(ArgumentError("Constant `$(s)` must be assigned a default value."))
21-
setmetadata(s, MTKConstantCtx, true)
15+
s = toparam(s)
16+
setmetadata(s, VariableTunable, false)
2217
end
2318

24-
toconstant(s::Num) = wrap(toconstant(value(s)))
19+
toconstant(s::Union{Num, Symbolics.Arr}) = wrap(toconstant(value(s)))
2520

2621
"""
2722
$(SIGNATURES)
@@ -36,15 +31,3 @@ macro constants(xs...)
3631
xs,
3732
toconstant) |> esc
3833
end
39-
40-
"""
41-
Substitute all `@constants` in the given expression
42-
"""
43-
function subs_constants(eqs)
44-
consts = collect_constants(eqs)
45-
if !isempty(consts)
46-
csubs = Dict(c => getdefault(c) for c in consts)
47-
eqs = substitute(eqs, csubs)
48-
end
49-
return eqs
50-
end

src/inputoutput.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ function generate_control_function(sys::AbstractSystem, inputs = unbound_inputs(
222222
disturbance_inputs = unwrap.(disturbance_inputs)
223223

224224
eqs = [eq for eq in full_equations(sys)]
225-
eqs = map(subs_constants, eqs)
225+
226226
if disturbance_inputs !== nothing && !disturbance_argument
227227
# Set all disturbance *inputs* to zero (we just want to keep the disturbance state)
228228
subs = Dict(disturbance_inputs .=> 0)
@@ -237,7 +237,6 @@ function generate_control_function(sys::AbstractSystem, inputs = unbound_inputs(
237237
p = reorder_parameters(sys, ps)
238238
t = get_iv(sys)
239239

240-
# pre = has_difference ? (ex -> ex) : get_postprocess_fbody(sys)
241240
if disturbance_argument
242241
args = (dvs, inputs, p..., t, disturbance_inputs)
243242
else

src/problems/optimizationproblem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,10 @@ function SciMLBase.OptimizationFunction{iip}(sys::System;
5656
else
5757
_cons_h = cons_hess_prototype = nothing
5858
end
59-
cons_expr = subs_constants(cstr)
59+
cons_expr = cstr
6060
end
6161

62-
obj_expr = subs_constants(cost(sys))
62+
obj_expr = cost(sys)
6363

6464
observedfun = ObservedFunctionCache(
6565
sys; expression, eval_expression, eval_module, checkbounds, cse)

src/structural_transformation/StructuralTransformations.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ using ModelingToolkit: System, AbstractSystem, var_from_nested_derivative, Diffe
2121
has_tearing_state, defaults, InvalidSystemException,
2222
ExtraEquationsSystemException,
2323
ExtraVariablesSystemException,
24-
get_postprocess_fbody, vars!,
24+
vars!,
2525
IncrementalCycleTracker, add_edge_checked!, topological_sort,
2626
invalidate_cache!, Substitutions, get_or_construct_tearing_state,
2727
filter_kwargs, lower_varname_with_unit,
Lines changed: 1 addition & 264 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using LinearAlgebra
22

3-
using ModelingToolkit: process_events, get_preprocess_constants
3+
using ModelingToolkit: process_events
44

55
const MAX_INLINE_NLSOLVE_SIZE = 8
66

@@ -96,136 +96,6 @@ function torn_system_with_nlsolve_jacobian_sparsity(state, var_eq_matching, var_
9696
sparse(I, J, true, length(eqs_idxs), length(states_idxs))
9797
end
9898

99-
function gen_nlsolve!(is_not_prepended_assignment, eqs, vars, u0map::AbstractDict,
100-
assignments, (deps, invdeps), var2assignment; checkbounds = true)
101-
isempty(vars) && throw(ArgumentError("vars may not be empty"))
102-
length(eqs) == length(vars) ||
103-
throw(ArgumentError("vars must be of the same length as the number of equations to find the roots of"))
104-
rhss = map(x -> x.rhs, eqs)
105-
# We use `vars` instead of `graph` to capture parameters, too.
106-
paramset = ModelingToolkit.vars(r for r in rhss)
107-
108-
# Compute necessary assignments for the nlsolve expr
109-
init_assignments = [var2assignment[p] for p in paramset if haskey(var2assignment, p)]
110-
if isempty(init_assignments)
111-
needed_assignments_idxs = Int[]
112-
needed_assignments = similar(assignments, 0)
113-
else
114-
tmp = [init_assignments]
115-
# `deps[init_assignments]` gives the dependency of `init_assignments`
116-
while true
117-
next_assignments = unique(reduce(vcat, deps[init_assignments]))
118-
isempty(next_assignments) && break
119-
init_assignments = next_assignments
120-
push!(tmp, init_assignments)
121-
end
122-
needed_assignments_idxs = unique(reduce(vcat, reverse(tmp)))
123-
needed_assignments = assignments[needed_assignments_idxs]
124-
end
125-
126-
# Compute `params`. They are like enclosed variables
127-
rhsvars = [ModelingToolkit.vars(r.rhs) for r in needed_assignments]
128-
vars_set = Set(vars)
129-
outer_set = BitSet()
130-
inner_set = BitSet()
131-
for (i, vs) in enumerate(rhsvars)
132-
j = needed_assignments_idxs[i]
133-
if isdisjoint(vars_set, vs)
134-
push!(outer_set, j)
135-
else
136-
push!(inner_set, j)
137-
end
138-
end
139-
init_refine = BitSet()
140-
for i in inner_set
141-
union!(init_refine, invdeps[i])
142-
end
143-
intersect!(init_refine, outer_set)
144-
setdiff!(outer_set, init_refine)
145-
union!(inner_set, init_refine)
146-
147-
next_refine = BitSet()
148-
while true
149-
for i in init_refine
150-
id = invdeps[i]
151-
isempty(id) && break
152-
union!(next_refine, id)
153-
end
154-
intersect!(next_refine, outer_set)
155-
isempty(next_refine) && break
156-
setdiff!(outer_set, next_refine)
157-
union!(inner_set, next_refine)
158-
159-
init_refine, next_refine = next_refine, init_refine
160-
empty!(next_refine)
161-
end
162-
global2local = Dict(j => i for (i, j) in enumerate(needed_assignments_idxs))
163-
inner_idxs = [global2local[i] for i in collect(inner_set)]
164-
outer_idxs = [global2local[i] for i in collect(outer_set)]
165-
extravars = reduce(union!, rhsvars[inner_idxs], init = Set())
166-
union!(paramset, extravars)
167-
setdiff!(paramset, vars)
168-
setdiff!(paramset, [needed_assignments[i].lhs for i in inner_idxs])
169-
union!(paramset, [needed_assignments[i].lhs for i in outer_idxs])
170-
params = collect(paramset)
171-
172-
# splatting to tighten the type
173-
u0 = []
174-
for v in vars
175-
v in keys(u0map) || (push!(u0, 1e-3); continue)
176-
u = substitute(v, u0map)
177-
for i in 1:length(u0map)
178-
u = substitute(u, u0map)
179-
u isa Number && (push!(u0, u); break)
180-
end
181-
u isa Number || error("$v doesn't have a default.")
182-
end
183-
u0 = [u0...]
184-
# specialize on the scalar case
185-
isscalar = length(u0) == 1
186-
u0 = isscalar ? u0[1] : SVector(u0...)
187-
188-
fname = gensym("fun")
189-
# f is the function to find roots on
190-
if isscalar
191-
funex = rhss[1]
192-
pre = get_preprocess_constants(funex)
193-
else
194-
funex = MakeArray(rhss, SVector)
195-
pre = get_preprocess_constants(rhss)
196-
end
197-
f = Func(
198-
[DestructuredArgs(vars, inbounds = !checkbounds)
199-
DestructuredArgs(params, inbounds = !checkbounds)],
200-
[],
201-
pre(Let(needed_assignments[inner_idxs],
202-
funex,
203-
false))) |> SymbolicUtils.Code.toexpr
204-
205-
# solver call contains code to call the root-finding solver on the function f
206-
solver_call = LiteralExpr(quote
207-
$numerical_nlsolve($fname,
208-
# initial guess
209-
$u0,
210-
# "captured variables"
211-
($(params...),))
212-
end)
213-
214-
preassignments = []
215-
for i in outer_idxs
216-
ii = needed_assignments_idxs[i]
217-
is_not_prepended_assignment[ii] || continue
218-
is_not_prepended_assignment[ii] = false
219-
push!(preassignments, assignments[ii])
220-
end
221-
222-
nlsolve_expr = Assignment[preassignments
223-
fname drop_expr(@RuntimeGeneratedFunction(f))
224-
DestructuredArgs(vars, inbounds = !checkbounds) solver_call]
225-
226-
nlsolve_expr
227-
end
228-
22999
"""
230100
find_solve_sequence(sccs, vars)
231101
@@ -242,136 +112,3 @@ function find_solve_sequence(sccs, vars)
242112
return find_solve_sequence(sccs, vars′)
243113
end
244114
end
245-
246-
function build_observed_function(state, ts, var_eq_matching, var_sccs,
247-
is_solver_unknown_idxs,
248-
assignments,
249-
deps,
250-
sol_states,
251-
var2assignment;
252-
expression = false,
253-
output_type = Array,
254-
checkbounds = true)
255-
is_not_prepended_assignment = trues(length(assignments))
256-
if (isscalar = !(ts isa AbstractVector))
257-
ts = [ts]
258-
end
259-
ts = unwrap.(Symbolics.scalarize(ts))
260-
261-
vars = Set()
262-
sys = state.sys
263-
foreach(Base.Fix1(vars!, vars), ts)
264-
ivs = independent_variables(sys)
265-
dep_vars = collect(setdiff(vars, ivs))
266-
267-
fullvars = state.fullvars
268-
s = state.structure
269-
unknown_vars = fullvars[is_solver_unknown_idxs]
270-
algvars = fullvars[.!is_solver_unknown_idxs]
271-
272-
required_algvars = Set(intersect(algvars, vars))
273-
obs = observed(sys)
274-
observed_idx = Dict(x.lhs => i for (i, x) in enumerate(obs))
275-
namespaced_to_obs = Dict(unknowns(sys, x.lhs) => x.lhs for x in obs)
276-
namespaced_to_sts = Dict(unknowns(sys, x) => x for x in unknowns(sys))
277-
sts = Set(unknowns(sys))
278-
279-
# FIXME: This is a rather rough estimate of dependencies. We assume
280-
# the expression depends on everything before the `maxidx`.
281-
subs = Dict()
282-
maxidx = 0
283-
for (i, s) in enumerate(dep_vars)
284-
idx = get(observed_idx, s, nothing)
285-
if idx !== nothing
286-
idx > maxidx && (maxidx = idx)
287-
else
288-
s′ = get(namespaced_to_obs, s, nothing)
289-
if s′ !== nothing
290-
subs[s] = s′
291-
s = s′
292-
idx = get(observed_idx, s, nothing)
293-
end
294-
if idx !== nothing
295-
idx > maxidx && (maxidx = idx)
296-
elseif !(s in sts)
297-
s′ = get(namespaced_to_sts, s, nothing)
298-
if s′ !== nothing
299-
subs[s] = s′
300-
continue
301-
end
302-
throw(ArgumentError("$s is either an observed nor an unknown variable."))
303-
end
304-
continue
305-
end
306-
end
307-
ts = map(t -> substitute(t, subs), ts)
308-
vs = Set()
309-
for idx in 1:maxidx
310-
vars!(vs, obs[idx].rhs)
311-
union!(required_algvars, intersect(algvars, vs))
312-
empty!(vs)
313-
end
314-
for eq in assignments
315-
vars!(vs, eq.rhs)
316-
union!(required_algvars, intersect(algvars, vs))
317-
empty!(vs)
318-
end
319-
320-
varidxs = findall(x -> x in required_algvars, fullvars)
321-
subset = find_solve_sequence(var_sccs, varidxs)
322-
if !isempty(subset)
323-
eqs = equations(sys)
324-
325-
nested_torn_vars_idxs = []
326-
for iscc in subset
327-
torn_vars_idxs = Int[var
328-
for var in var_sccs[iscc]
329-
if var_eq_matching[var] !== unassigned]
330-
isempty(torn_vars_idxs) || push!(nested_torn_vars_idxs, torn_vars_idxs)
331-
end
332-
torn_eqs = [[eqs[var_eq_matching[i]] for i in idxs]
333-
for idxs in nested_torn_vars_idxs]
334-
torn_vars = [fullvars[idxs] for idxs in nested_torn_vars_idxs]
335-
u0map = defaults(sys)
336-
assignments = copy(assignments)
337-
solves = map(zip(torn_eqs, torn_vars)) do (eqs, vars)
338-
gen_nlsolve!(is_not_prepended_assignment, eqs, vars,
339-
u0map, assignments, deps, var2assignment;
340-
checkbounds = checkbounds)
341-
end
342-
else
343-
solves = []
344-
end
345-
346-
subs = []
347-
for sym in vars
348-
eqidx = get(observed_idx, sym, nothing)
349-
eqidx === nothing && continue
350-
push!(subs, sym obs[eqidx].rhs)
351-
end
352-
pre = get_postprocess_fbody(sys)
353-
cpre = get_preprocess_constants([obs[1:maxidx];
354-
isscalar ? ts[1] : MakeArray(ts, output_type)])
355-
pre2 = x -> pre(cpre(x))
356-
ex = Code.toexpr(
357-
Func(
358-
[DestructuredArgs(unknown_vars, inbounds = !checkbounds)
359-
DestructuredArgs(parameters(sys), inbounds = !checkbounds)
360-
independent_variables(sys)],
361-
[],
362-
pre2(Let(
363-
[collect(Iterators.flatten(solves))
364-
assignments[is_not_prepended_assignment]
365-
map(eq -> eq.lhs eq.rhs, obs[1:maxidx])
366-
subs],
367-
isscalar ? ts[1] : MakeArray(ts, output_type),
368-
false))),
369-
sol_states)
370-
371-
expression ? ex : drop_expr(@RuntimeGeneratedFunction(ex))
372-
end
373-
374-
struct ODAEProblem{iip} end
375-
376-
@deprecate ODAEProblem(args...; kw...) ODEProblem(args...; kw...)
377-
@deprecate ODAEProblem{iip}(args...; kw...) where {iip} ODEProblem{iip}(args...; kw...)

src/structural_transformation/utils.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -224,14 +224,12 @@ function find_eq_solvables!(state::TearingState, ieq, to_rm = Int[], coeffs = no
224224
a, b, islinear = linear_expansion(term, var)
225225
a, b = unwrap(a), unwrap(b)
226226
islinear || (all_int_vars = false; continue)
227-
a = ModelingToolkit.fold_constants(a)
228-
b = ModelingToolkit.fold_constants(b)
229227
if a isa Symbolic
230228
all_int_vars = false
231229
if !allow_symbolic
232230
if allow_parameter
233231
all(
234-
x -> ModelingToolkit.isparameter(x) || ModelingToolkit.isconstant(x),
232+
x -> ModelingToolkit.isparameter(x),
235233
vars(a)) || continue
236234
else
237235
continue

0 commit comments

Comments
 (0)