Skip to content

Commit 1992424

Browse files
committed
Refactor code
1 parent 16449ae commit 1992424

File tree

3 files changed

+21
-67
lines changed

3 files changed

+21
-67
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -78,19 +78,7 @@ function generate_control_jacobian(sys::AbstractODESystem, dvs = states(sys), ps
7878
return build_function(jac, dvs, ps, get_iv(sys); kwargs...)
7979
end
8080

81-
@noinline function throw_invalid_derivative(dervar, eq)
82-
msg = "The derivative variable must be isolated to the left-hand " *
83-
"side of the equation like `$dervar ~ ...`.\n Got $eq."
84-
throw(InvalidSystemException(msg))
85-
end
86-
87-
function check_derivative_variables(eq, expr=eq.rhs)
88-
istree(expr) || return nothing
89-
if operation(expr) isa Differential
90-
throw_invalid_derivative(expr, eq)
91-
end
92-
foreach(Base.Fix1(check_derivative_variables, eq), arguments(expr))
93-
end
81+
check_derivative_variables(eq) = check_operator_variables(eq, Differential)
9482

9583
function generate_function(
9684
sys::AbstractODESystem, dvs = states(sys), ps = parameters(sys);

src/systems/discrete_system/discrete_system.jl

Lines changed: 1 addition & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -124,60 +124,7 @@ end
124124
isdifference(expr) = istree(expr) && operation(expr) isa Difference
125125
isdifferenceeq(eq) = isdifference(eq.lhs)
126126

127-
difference_vars(x::Sym) = Set([x])
128-
difference_vars(exprs::Symbolic) = difference_vars([exprs])
129-
difference_vars(exprs) = foldl(difference_vars!, exprs; init = Set())
130-
difference_vars!(difference_vars, eq::Equation) = (difference_vars!(difference_vars, eq.lhs); difference_vars!(difference_vars, eq.rhs); difference_vars)
131-
function difference_vars!(difference_vars, O)
132-
if isa(O, Sym)
133-
return push!(difference_vars, O)
134-
end
135-
!istree(O) && return difference_vars
136-
137-
operation(O) isa Difference && return push!(difference_vars, O)
138-
139-
if operation(O) === (getindex) &&
140-
first(arguments(O)) isa Symbolic
141-
142-
return push!(difference_vars, O)
143-
end
144-
145-
symtype(operation(O)) <: FnType && push!(difference_vars, O)
146-
for arg in arguments(O)
147-
difference_vars!(difference_vars, arg)
148-
end
149-
150-
return difference_vars
151-
end
152-
153-
function collect_difference_variables(sys::DiscreteSystem)
154-
eqs = equations(sys)
155-
vars = Set()
156-
difference_vars = Set()
157-
for eq in eqs
158-
difference_vars!(vars, eq)
159-
for v in vars
160-
isdifference(v) || continue
161-
push!(difference_vars, arguments(v)[1])
162-
end
163-
empty!(vars)
164-
end
165-
return difference_vars
166-
end
167-
168-
@noinline function throw_invalid_difference(difvar, eq)
169-
msg = "The difference variable must be isolated to the left-hand " *
170-
"side of the equation like `$difvar ~ ...`.\n Got $eq."
171-
throw(InvalidSystemException(msg))
172-
end
173-
174-
function check_difference_variables(eq, expr=eq.rhs)
175-
istree(expr) || return nothing
176-
if operation(expr) isa Difference
177-
throw_invalid_difference(expr, eq)
178-
end
179-
foreach(Base.Fix1(check_difference_variables, eq), arguments(expr))
180-
end
127+
check_difference_variables(eq) = check_operator_variables(eq, Difference)
181128

182129
function generate_function(
183130
sys::DiscreteSystem, dvs = states(sys), ps = parameters(sys);

src/utils.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,3 +166,22 @@ function collect_defaults!(defs, vars)
166166
end
167167
return defs
168168
end
169+
170+
@noinline function throw_invalid_operator(opvar, eq, op::Type)
171+
if op === Difference
172+
optext = "difference"
173+
elseif op === Differential
174+
optext="derivative"
175+
end
176+
msg = "The $optext variable must be isolated to the left-hand " *
177+
"side of the equation like `$opvar ~ ...`.\n Got $eq."
178+
throw(InvalidSystemException(msg))
179+
end
180+
181+
function check_operator_variables(eq, op::Type, expr=eq.rhs)
182+
istree(expr) || return nothing
183+
if operation(expr) isa op
184+
throw_invalid_operator(expr, eq, op)
185+
end
186+
foreach(expr -> check_operator_variables(eq, op, expr), arguments(expr))
187+
end

0 commit comments

Comments
 (0)