Skip to content

Commit 4f999b1

Browse files
authored
Merge pull request #1599 from SciML/myb/io
Add a pass that converts unbound inputs to parameters
2 parents 56e5844 + b20d4dc commit 4f999b1

File tree

6 files changed

+85
-4
lines changed

6 files changed

+85
-4
lines changed

src/ModelingToolkit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,6 @@ using .BipartiteGraphs
115115

116116
include("variables.jl")
117117
include("parameters.jl")
118-
include("inputoutput.jl")
119118

120119
include("utils.jl")
121120
include("domains.jl")
@@ -152,6 +151,7 @@ include("systems/alias_elimination.jl")
152151
include("structural_transformation/StructuralTransformations.jl")
153152

154153
@reexport using .StructuralTransformations
154+
include("inputoutput.jl")
155155

156156
for S in subtypes(ModelingToolkit.AbstractSystem)
157157
S = nameof(S)

src/inputoutput.jl

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,3 +239,65 @@ function toparam(sys, ctrls::AbstractVector)
239239
end
240240
ODESystem(eqs, name = nameof(sys))
241241
end
242+
243+
function inputs_to_parameters!(state::TransformationState)
244+
@unpack structure, fullvars, sys = state
245+
@unpack var_to_diff, graph, solvable_graph = structure
246+
@assert solvable_graph === nothing
247+
248+
inputs = BitSet()
249+
var_reidx = zeros(Int, length(fullvars))
250+
ninputs = 0
251+
nvar = 0
252+
new_parameters = []
253+
input_to_parameters = Dict()
254+
new_fullvars = []
255+
for (i, v) in enumerate(fullvars)
256+
if isinput(v) && !is_bound(sys, v)
257+
if var_to_diff[i] !== nothing
258+
error("Input $(fullvars[i]) is differentiated!")
259+
end
260+
push!(inputs, i)
261+
ninputs += 1
262+
var_reidx[i] = -1
263+
p = toparam(v)
264+
push!(new_parameters, p)
265+
input_to_parameters[v] = p
266+
else
267+
nvar += 1
268+
var_reidx[i] = nvar
269+
push!(new_fullvars, v)
270+
end
271+
end
272+
ninputs == 0 && return state
273+
274+
nvars = ndsts(graph) - ninputs
275+
new_graph = BipartiteGraph(nsrcs(graph), nvars, Val(false))
276+
277+
for ie in 1:nsrcs(graph)
278+
for iv in 𝑠neighbors(graph, ie)
279+
iv = var_reidx[iv]
280+
iv > 0 || continue
281+
add_edge!(new_graph, ie, iv)
282+
end
283+
end
284+
285+
new_var_to_diff = DiffGraph(nvars, true)
286+
for (i, v) in enumerate(var_to_diff)
287+
new_i = var_reidx[i]
288+
(new_i < 1 || v === nothing) && continue
289+
new_v = var_reidx[v]
290+
@assert new_v > 0
291+
new_var_to_diff[new_i] = new_v
292+
end
293+
@set! structure.var_to_diff = new_var_to_diff
294+
@set! structure.graph = new_graph
295+
296+
@set! sys.eqs = map(Base.Fix2(substitute, input_to_parameters), equations(sys))
297+
@set! sys.states = setdiff(states(sys), keys(input_to_parameters))
298+
@set! sys.ps = [parameters(sys); new_parameters]
299+
300+
@set! state.sys = sys
301+
@set! state.fullvars = new_fullvars
302+
@set! state.structure = structure
303+
end

src/structural_transformation/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ function find_eq_solvables!(state::TearingState, ieq; may_be_zero = false,
170170
to_rm = Int[]
171171
for j in 𝑠neighbors(graph, ieq)
172172
var = fullvars[j]
173-
isinput(var) && continue
173+
#isinput(var) && continue
174174
a, b, islinear = linear_expansion(term, var)
175175
a = unwrap(a)
176176
islinear || continue

src/systems/abstractsystem.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1002,6 +1002,8 @@ function structural_simplify(sys::AbstractSystem; simplify = false, kwargs...)
10021002
sys = expand_connections(sys)
10031003
sys = alias_elimination(sys)
10041004
state = TearingState(sys)
1005+
state = inputs_to_parameters!(state)
1006+
sys = state.sys
10051007
check_consistency(state)
10061008
if sys isa ODESystem
10071009
sys = dae_order_lowering(dummy_derivative(sys, state))

test/input_output_handling.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,20 +31,32 @@ D = Differential(tv)
3131
@test !is_bound(sys2, sys.u)
3232
@test !is_bound(sys2, sys2.sys.u)
3333

34+
fsys2 = flatten(sys2)
35+
@test is_bound(fsys2, sys.x)
36+
@test !is_bound(fsys2, sys.u)
37+
@test !is_bound(fsys2, sys2.sys.u)
38+
39+
3440
@test is_bound(sys3, sys.u) # I would like to write sys3.sys.u here but that's not how the variable is stored in the equations
3541
@test is_bound(sys3, sys.x)
3642

3743
@test is_bound(sys4, sys.u)
3844
@test !is_bound(sys4, u)
3945

46+
fsys4 = flatten(sys4)
47+
@test is_bound(fsys4, sys.u)
48+
@test !is_bound(fsys4, u)
49+
4050
@test isequal(inputs(sys), [u])
4151
@test isequal(inputs(sys2), [sys.u])
4252

4353
@test isempty(bound_inputs(sys))
4454
@test isequal(unbound_inputs(sys), [u])
4555

4656
@test isempty(bound_inputs(sys2))
57+
@test isempty(bound_inputs(fsys2))
4758
@test isequal(unbound_inputs(sys2), [sys.u])
59+
@test isequal(unbound_inputs(fsys2), [sys.u])
4860

4961
@test isequal(bound_inputs(sys3), [sys.u])
5062
@test isempty(unbound_inputs(sys3))
@@ -161,3 +173,9 @@ p = ModelingToolkit.varmap_to_vars(ModelingToolkit.defaults(model), ps)
161173
x = ModelingToolkit.varmap_to_vars(ModelingToolkit.defaults(model), dvs)
162174
u = [rand()]
163175
@test f[1](x, u, p, 1) == [u; 0; 0; 0]
176+
177+
@parameters t
178+
@variables x(t) u(t) [input=true]
179+
eqs = [Differential(t)(x) ~ u]
180+
@named sys = ODESystem(eqs, t)
181+
structural_simplify(sys)

test/reduction.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -240,10 +240,9 @@ D = Differential(t)
240240

241241
eqs = [D(x) ~ σ * (y - x)
242242
D(y) ~ x *- z) - y + β
243-
0 ~ z - x + y
244243
0 ~ a + z
245244
u ~ z + a]
246245

247246
lorenz1 = ODESystem(eqs, t, name = :lorenz1)
248247
lorenz1_reduced = structural_simplify(lorenz1)
249-
@test z in Set(states(lorenz1_reduced))
248+
@test z in Set(parameters(lorenz1_reduced))

0 commit comments

Comments
 (0)