Skip to content

Commit 0b84d17

Browse files
committed
Add a pass that converts unbound inputs to parameters
1 parent 28b36c8 commit 0b84d17

File tree

5 files changed

+72
-2
lines changed

5 files changed

+72
-2
lines changed

src/ModelingToolkit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,6 @@ using .BipartiteGraphs
115115

116116
include("variables.jl")
117117
include("parameters.jl")
118-
include("inputoutput.jl")
119118

120119
include("utils.jl")
121120
include("domains.jl")
@@ -152,6 +151,7 @@ include("systems/alias_elimination.jl")
152151
include("structural_transformation/StructuralTransformations.jl")
153152

154153
@reexport using .StructuralTransformations
154+
include("inputoutput.jl")
155155

156156
for S in subtypes(ModelingToolkit.AbstractSystem)
157157
S = nameof(S)

src/inputoutput.jl

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,3 +239,65 @@ function toparam(sys, ctrls::AbstractVector)
239239
end
240240
ODESystem(eqs, name = nameof(sys))
241241
end
242+
243+
function inputs_to_parameters!(state::TransformationState)
244+
@unpack structure, fullvars, sys = state
245+
@unpack var_to_diff, graph, solvable_graph = structure
246+
@assert solvable_graph === nothing
247+
248+
inputs = BitSet()
249+
var_reidx = zeros(Int, length(fullvars))
250+
ninputs = 0
251+
nvar = 0
252+
new_parameters = []
253+
input_to_parameters = Dict()
254+
new_fullvars = []
255+
for (i, v) in enumerate(fullvars)
256+
if isinput(v) && !is_bound(sys, v)
257+
if var_to_diff[i] !== nothing
258+
error("Input $(fullvars[i]) is differentiated!")
259+
end
260+
push!(inputs, i)
261+
ninputs += 1
262+
var_reidx[i] = -1
263+
p = toparam(v)
264+
push!(new_parameters, p)
265+
input_to_parameters[v] = p
266+
else
267+
nvar += 1
268+
var_reidx[i] = nvar
269+
push!(new_fullvars, v)
270+
end
271+
end
272+
ninputs == 0 && return state
273+
274+
nvars = ndsts(graph) - ninputs
275+
new_graph = BipartiteGraph(nsrcs(graph), nvars, Val(false))
276+
277+
for ie in 1:nsrcs(graph)
278+
for iv in 𝑠neighbors(graph, ie)
279+
iv = var_reidx[iv]
280+
iv > 0 || continue
281+
add_edge!(new_graph, ie, iv)
282+
end
283+
end
284+
285+
new_var_to_diff = DiffGraph(nvars, true)
286+
for (i, v) in enumerate(var_to_diff)
287+
new_i = var_reidx[i]
288+
(new_i < 1 || v === nothing) && continue
289+
new_v = var_reidx[v]
290+
@assert new_v > 0
291+
new_var_to_diff[new_i] = new_v
292+
end
293+
@set! structure.var_to_diff = new_var_to_diff
294+
@set! structure.graph = new_graph
295+
296+
@set! sys.eqs = map(Base.Fix2(substitute, input_to_parameters), equations(sys))
297+
@set! sys.states = setdiff(states(sys), keys(input_to_parameters))
298+
@set! sys.ps = [parameters(sys); new_parameters]
299+
300+
@set! state.sys = sys
301+
@set! state.fullvars = new_fullvars
302+
@set! state.structure = structure
303+
end

src/structural_transformation/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ function find_eq_solvables!(state::TearingState, ieq; may_be_zero = false,
170170
to_rm = Int[]
171171
for j in 𝑠neighbors(graph, ieq)
172172
var = fullvars[j]
173-
isinput(var) && continue
173+
#isinput(var) && continue
174174
a, b, islinear = linear_expansion(term, var)
175175
a = unwrap(a)
176176
islinear || continue

src/systems/abstractsystem.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1002,6 +1002,8 @@ function structural_simplify(sys::AbstractSystem; simplify = false, kwargs...)
10021002
sys = expand_connections(sys)
10031003
sys = alias_elimination(sys)
10041004
state = TearingState(sys)
1005+
state = inputs_to_parameters!(state)
1006+
sys = state.sys
10051007
check_consistency(state)
10061008
if sys isa ODESystem
10071009
sys = dae_order_lowering(dummy_derivative(sys, state))

test/input_output_handling.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,3 +161,9 @@ p = ModelingToolkit.varmap_to_vars(ModelingToolkit.defaults(model), ps)
161161
x = ModelingToolkit.varmap_to_vars(ModelingToolkit.defaults(model), dvs)
162162
u = [rand()]
163163
@test f[1](x, u, p, 1) == [u; 0; 0; 0]
164+
165+
@parameters t
166+
@variables x(t) u(t) [input=true]
167+
eqs = [Differential(t)(x) ~ u]
168+
@named sys = ODESystem(eqs, t)
169+
structural_simplify(sys)

0 commit comments

Comments
 (0)