Skip to content

Commit 828d461

Browse files
committed
some more genericness
1 parent 0d70c3b commit 828d461

File tree

4 files changed

+25
-17
lines changed

4 files changed

+25
-17
lines changed

src/build_function.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -476,7 +476,7 @@ end
476476

477477
get_varnumber(varop, vars::Vector) = findfirst(x->isequal(x,varop),vars)
478478

479-
function numbered_expr(O::Union{Term,Sym},args...;varordering = args[1],offset = 0,
479+
function numbered_expr(O::Symbolic,args...;varordering = args[1],offset = 0,
480480
lhsname=gensym("du"),rhsnames=[gensym("MTK") for i in 1:length(args)])
481481
O = value(O)
482482
if O isa Sym || isa(operation(O), Sym)

src/systems/abstractsystem.jl

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,13 @@ Generate a function to evaluate the system's equations.
117117
"""
118118
function generate_function end
119119

120-
getname(x::Sym) = nameof(x)
121-
getname(t::Term) = operation(t) isa Sym ? getname(operation(t)) : error("Cannot get name of $t")
120+
function getname(t)
121+
if istree(t)
122+
operation(t) isa Sym ? getname(operation(t)) : error("Cannot get name of $t")
123+
else
124+
nameof(t)
125+
end
126+
end
122127

123128
function Base.getproperty(sys::AbstractSystem, name::Symbol)
124129

@@ -188,14 +193,17 @@ function namespace_expr(O::Sym,name,ivname)
188193
O.name == ivname ? O : rename(O,renamespace(name,O.name))
189194
end
190195

191-
function namespace_expr(O::Term{T},name,ivname) where {T}
192-
if operation(O) isa Sym
193-
Term{T}(rename(operation(O),renamespace(name,operation(O).name)),namespace_expr.(arguments(O),name,ivname))
196+
function namespace_expr(O,name,ivname) where {T}
197+
if istree(O)
198+
if operation(O) isa Sym
199+
Term{T}(rename(operation(O),renamespace(name,operation(O).name)),namespace_expr.(arguments(O),name,ivname))
200+
else
201+
similarterm(O,operation(O),namespace_expr.(arguments(O),name,ivname))
202+
end
194203
else
195-
Term{T}(operation(O),namespace_expr.(arguments(O),name,ivname))
204+
O
196205
end
197206
end
198-
namespace_expr(O,name,ivname) = O
199207

200208
independent_variable(sys::AbstractSystem) = sys.iv
201209
function states(sys::AbstractSystem)
@@ -277,12 +285,12 @@ struct AbstractSysToExpr
277285
states::Vector
278286
end
279287
AbstractSysToExpr(sys) = AbstractSysToExpr(sys,states(sys))
280-
function (f::AbstractSysToExpr)(O::Term)
288+
function (f::AbstractSysToExpr)(O)
289+
!istree(O) && return toexpr(O)
281290
any(isequal(O), f.states) && return operation(O).name # variables
282291
if isa(operation(O), Sym)
283292
return build_expr(:call, Any[operation(O).name; f.(arguments(O))])
284293
end
285294
return build_expr(:call, Any[operation(O); f.(arguments(O))])
286295
end
287-
(f::AbstractSysToExpr)(x) = toexpr(x)
288296

src/systems/control/controlsystem.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ sys = ControlSystem(loss,eqs,t,[x,v],[u],[])
4747
"""
4848
struct ControlSystem <: AbstractControlSystem
4949
"""The Loss function"""
50-
loss::Term
50+
loss::Any
5151
"""The ODEs defining the system."""
5252
eqs::Vector{Equation}
5353
"""Independent variable."""
@@ -89,7 +89,8 @@ struct ControlToExpr
8989
controls::Vector
9090
end
9191
ControlToExpr(@nospecialize(sys)) = ControlToExpr(sys,states(sys),controls(sys))
92-
function (f::ControlToExpr)(O::Term)
92+
function (f::ControlToExpr)(O)
93+
!istree(O) && return O
9394
res = if isa(operation(O), Sym)
9495
# normal variables and control variables
9596
(any(isequal(O), f.states) || any(isequal(O), f.controls)) && return tosymbol(O)
@@ -99,7 +100,6 @@ function (f::ControlToExpr)(O::Term)
99100
end
100101
end
101102
(f::ControlToExpr)(x::Sym) = x.name
102-
(f::ControlToExpr)(x) = x
103103

104104
function constructRadauIIA5(T::Type = Float64)
105105
sq6 = sqrt(6)

src/systems/optimization/optimizationsystem.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ end
6464
function generate_hessian(sys::OptimizationSystem, vs = states(sys), ps = parameters(sys);
6565
sparse = false, kwargs...)
6666
if sparse
67-
hess = sparsehessian(equations(sys),[dv() for dv in states(sys)])
67+
hess = sparsehessian(equations(sys),states(sys))
6868
else
6969
hess = calculate_hessian(sys)
7070
end
@@ -77,11 +77,11 @@ function generate_function(sys::OptimizationSystem, vs = states(sys), ps = param
7777
conv = AbstractSysToExpr(sys),kwargs...)
7878
end
7979

80-
equations(sys::OptimizationSystem) = isempty(sys.systems) ? operation(sys) : operation(sys) + reduce(+,namespace_expr.(sys.systems))
81-
namespace_expr(sys::OptimizationSystem) = namespace_expr(operation(sys),sys.name,nothing)
80+
equations(sys::OptimizationSystem) = isempty(sys.systems) ? sys.op : sys.op + reduce(+,namespace_expr.(sys.systems))
81+
namespace_expr(sys::OptimizationSystem) = namespace_expr(sys.op,sys.name,nothing)
8282

8383
hessian_sparsity(sys::OptimizationSystem) =
84-
hessian_sparsity(operation(sys), states(sys))
84+
hessian_sparsity(sys.op, states(sys))
8585

8686
struct AutoModelingToolkit <: DiffEqBase.AbstractADType end
8787

0 commit comments

Comments
 (0)