Skip to content

Commit b4f6b38

Browse files
authored
Merge pull request #1208 from SciML/myb/opt
Don't reduce inputs
2 parents 9ecbb77 + d5f3a26 commit b4f6b38

File tree

8 files changed

+42
-6
lines changed

8 files changed

+42
-6
lines changed

src/ModelingToolkit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ import Symbolics: rename, get_variables!, _solve, hessian_sparsity,
5252
ParallelForm, SerialForm, MultithreadedForm, build_function,
5353
unflatten_long_ops, rhss, lhss, prettify_expr, gradient,
5454
jacobian, hessian, derivative, sparsejacobian, sparsehessian,
55-
substituter, scalarize
55+
substituter, scalarize, getparent
5656

5757
import DiffEqBase: @add_kwonly
5858

src/structural_transformation/StructuralTransformations.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ using ModelingToolkit
1616
using ModelingToolkit: ODESystem, AbstractSystem,var_from_nested_derivative, Differential,
1717
states, equations, vars, Symbolic, diff2term, value,
1818
operation, arguments, Sym, Term, simplify, solve_for,
19-
isdiffeq, isdifferential, get_structure, get_iv, independent_variables,
19+
isdiffeq, isdifferential, isinput,
20+
get_structure, get_iv, independent_variables,
2021
get_structure, defaults, InvalidSystemException,
2122
ExtraEquationsSystemException,
2223
ExtraVariablesSystemException,

src/structural_transformation/utils.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,9 @@ function find_solvables!(sys)
276276
term = value(eq.rhs - eq.lhs)
277277
for j in 𝑠neighbors(graph, i)
278278
isalgvar(s, j) || continue
279-
a, b, islinear = linear_expansion(term, fullvars[j])
279+
var = fullvars[j]
280+
isinput(var) && continue
281+
a, b, islinear = linear_expansion(term, var)
280282
a = unwrap(a)
281283
if islinear && (!(a isa Symbolic) && a isa Number && a != 0)
282284
add_edge!(solvable_graph, i, j)

src/systems/systemstructure.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ using Symbolics: linear_expansion, unwrap
55
using SymbolicUtils: istree, operation, arguments, Symbolic
66
using ..ModelingToolkit
77
import ..ModelingToolkit: isdiffeq, var_from_nested_derivative, vars!, flatten,
8-
value, InvalidSystemException, isdifferential, _iszero, isparameter, independent_variables
8+
value, InvalidSystemException, isdifferential, _iszero, isparameter,
9+
independent_variables, isinput
910
using ..BipartiteGraphs
1011
using LightGraphs
1112
using UnPack
@@ -230,7 +231,7 @@ function find_linear_equations(sys)
230231
var = fullvars[j]
231232
a, b, islinear = linear_expansion(term, var)
232233
a = unwrap(a)
233-
if islinear && !(a isa Symbolic) && a isa Number
234+
if islinear && !(a isa Symbolic) && a isa Number && !isinput(var)
234235
if a == 1 || a == -1
235236
a = convert(Integer, a)
236237
linear_term += a * var

src/utils.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,12 @@ isdiffeq(eq) = isdifferential(eq.lhs)
225225
isdifference(expr) = istree(expr) && operation(expr) isa Difference
226226
isdifferenceeq(eq) = isdifference(eq.lhs)
227227

228-
isvariable(x) = x isa Symbolic && hasmetadata(x, VariableSource)
228+
function isvariable(x)
229+
x isa Symbolic || return false
230+
p = getparent(x, nothing)
231+
p === nothing || (x = p)
232+
hasmetadata(x, VariableSource)
233+
end
229234

230235
vars(x::Sym; op=Differential) = Set([x])
231236
vars(exprs::Symbolic; op=Differential) = vars([exprs]; op=op)

src/variables.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,15 @@ Symbolics.option_to_metadata_type(::Val{:description}) = VariableDescriptionType
1111
Symbolics.option_to_metadata_type(::Val{:input}) = VariableInput
1212
Symbolics.option_to_metadata_type(::Val{:output}) = VariableOutput
1313

14+
function isvarkind(m, x)
15+
p = getparent(x, nothing)
16+
p === nothing || (x = p)
17+
getmetadata(x, m, false)
18+
end
19+
20+
isinput(x) = isvarkind(VariableInput, x)
21+
isoutput(x) = isvarkind(VariableOutput, x)
22+
1423
"""
1524
$(SIGNATURES)
1625

test/odesystem.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,7 @@ eqs = [
399399
@test isequal(@nonamespace(sys.y), unwrap(y))
400400
@test isequal(@nonamespace(sys.p), unwrap(p))
401401
@test_nowarn sys.x, sys.y, sys.p
402+
@test ModelingToolkit.isvariable(Symbolics.unwrap(x[1]))
402403

403404
# Mixed Difference Differential equations
404405
@parameters t a b c d

test/reduction.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,3 +281,20 @@ dt = ModelingToolkit.value(ModelingToolkit.derivative(eq.rhs, sin(10t)))
281281
@test dv25 0.3
282282
@test ddv25 == 0.005
283283
@test dt == -0.1
284+
285+
# Don't reduce inputs
286+
@parameters t σ ρ β
287+
@variables x(t) y(t) z(t) [input=true] a(t) u(t) F(t)
288+
D = Differential(t)
289+
290+
eqs = [
291+
D(x) ~ σ*(y-x)
292+
D(y) ~ x*-z)-y + β
293+
0 ~ z - x + y
294+
0 ~ a + z
295+
u ~ z + a
296+
]
297+
298+
lorenz1 = ODESystem(eqs,t,name=:lorenz1)
299+
lorenz1_reduced = structural_simplify(lorenz1)
300+
@test z in Set(states(lorenz1_reduced))

0 commit comments

Comments
 (0)