Skip to content

Commit f380639

Browse files
authored
Merge pull request #618 from SciML/ys/makesym
More serious Symbol and makesym
2 parents f9e5c33 + 63c5c24 commit f380639

File tree

6 files changed

+56
-34
lines changed

6 files changed

+56
-34
lines changed

src/build_function.jl

Lines changed: 6 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -458,27 +458,12 @@ end
458458

459459
vars_to_pairs(args) = vars_to_pairs(args[1],args[2])
460460
function vars_to_pairs(name,vs::AbstractArray)
461-
vs_names = [term_to_symbol(value(u)) for u vs]
462-
exs = [:($name[$i]) for (i, u) enumerate(vs)]
463-
vs_names,exs
461+
vs_names = [tosymbol(u) for u vs]
462+
exs = [:($name[$i]) for (i, u) enumerate(vs)]
463+
vs_names,exs
464464
end
465-
466-
function term_to_symbol(t::Term)
467-
if operation(t) isa Sym
468-
s = nameof(operation(t))
469-
else
470-
error("really?")
471-
end
472-
end
473-
474-
term_to_symbol(s::Sym) = nameof(s)
475-
476465
function vars_to_pairs(name,vs)
477-
[term_to_symbol(value(vs))], [name]
478-
end
479-
480-
function rm_calls_with_iv(expr)
481-
Rewriters.Prewalk(Rewriters.Chain([@rule((~f::(x->x isa Sym))(~t::(x->x isa Sym)) => Sym{symtype((~f)(~t))}((term_to_symbol(~f))))]))(value(expr))
466+
[tosymbol(vs)], [name]
482467
end
483468

484469
get_varnumber(varop, vars::Vector) = findfirst(x->isequal(x,varop),vars)
@@ -494,7 +479,7 @@ function numbered_expr(O::Union{Term,Sym},args...;varordering = args[1],offset =
494479
end
495480
end
496481
end
497-
return Expr(:call, O isa Sym ? nameof(O) : Symbol(O.op),
482+
return Expr(:call, O isa Sym ? tosymbol(O, escape=false) : Symbol(O.op),
498483
[numbered_expr(x,args...;offset=offset,lhsname=lhsname,
499484
rhsnames=rhsnames,varordering=varordering) for x in O.args]...)
500485
end
@@ -504,7 +489,7 @@ function numbered_expr(de::ModelingToolkit.Equation,args...;varordering = args[1
504489

505490
varordering = value.(args[1])
506491
var = var_from_nested_derivative(de.lhs)[1]
507-
i = findfirst(x->isequal(x isa Sym ? term_to_symbol(x) : term_to_symbol(x.op),term_to_symbol(var)),varordering)
492+
i = findfirst(x->isequal(tosymbol(x isa Sym ? x : x.op, escape=false), tosymbol(var, escape=false)),varordering)
508493
:($lhsname[$(i+offset)] = $(numbered_expr(de.rhs,args...;offset=offset,
509494
varordering = varordering,
510495
lhsname = lhsname,

src/systems/control/controlsystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,8 @@ end
9191
ControlToExpr(@nospecialize(sys)) = ControlToExpr(sys,states(sys),controls(sys))
9292
function (f::ControlToExpr)(O::Term)
9393
res = if isa(O.op, Sym)
94-
any(isequal(O), f.states) && return O.op.name # dependent variables
95-
any(isequal(O), f.controls) && return O.op.name # control variables
94+
# normal variables and control variables
95+
(any(isequal(O), f.states) || any(isequal(O), f.controls)) && return tosymbol(O)
9696
build_expr(:call, Any[O.op.name; f.(O.args)])
9797
else
9898
build_expr(:call, Any[Symbol(O.op); f.(O.args)])

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ function calculate_tgrad(sys::AbstractODESystem;
1111
if simplify
1212
tgrad = ModelingToolkit.simplify.(notime_tgrad)
1313
end
14+
xs = states(sys)
15+
rule = Dict(map((x, xt) -> x=>xt, detime_dvs.(xs), xs))
16+
tgrad = substitute.(tgrad, Ref(rule))
1417
sys.tgrad[] = tgrad
1518
return tgrad
1619
end
@@ -41,7 +44,8 @@ ODEToExpr(@nospecialize(sys)) = ODEToExpr(sys,states(sys))
4144
(f::ODEToExpr)(O::Num) = f(value(O))
4245
function (f::ODEToExpr)(O::Term)
4346
if isa(O.op, Sym)
44-
any(isequal(O), f.states) && return O.op.name # dependent variables
47+
any(isequal(O), f.states) && return tosymbol(O)
48+
# dependent variables
4549
return build_expr(:call, Any[O.op.name; f.(O.args)])
4650
end
4751
return build_expr(:call, Any[O.op; f.(O.args)])
@@ -64,18 +68,49 @@ function generate_jacobian(sys::AbstractODESystem, dvs = states(sys), ps = param
6468
conv = ODEToExpr(sys), kwargs...)
6569
end
6670

67-
function makesym(t::Term{T}) where {T}
68-
t.op isa Sym && return makesym(t.op)
69-
t.op isa Differential && return Sym{T}(Symbol(nameof(makesym(t.args[1])), , nameof(makesym(t.op.x))))
71+
Base.Symbol(x::Union{Num,Symbolic}) = tosymbol(x)
72+
tosymbol(x; kwargs...) = x
73+
tosymbol(x::Sym; kwargs...) = nameof(x)
74+
tosymbol(t::Num; kwargs...) = tosymbol(value(t); kwargs...)
75+
76+
"""
77+
tosymbol(x::Union{Num,Symbolic}; states=nothing, escape=true) -> Symbol
78+
79+
Convert `x` to a symbol. `states` are the states of a system, and `escape`
80+
means if the target has escapes like `val"y⦗t⦘"`. If `escape` then it will only
81+
output `y` instead of `y⦗t⦘`.
82+
"""
83+
function tosymbol(t::Term; states=nothing, escape=true)
84+
if t.op isa Sym
85+
if states !== nothing && !(any(isequal(t), states))
86+
return nameof(t.op)
87+
end
88+
op = nameof(t.op)
89+
args = t.args
90+
elseif t.op isa Differential
91+
if !(t.args[1].op isa Sym)
92+
@goto err
93+
end
94+
op = Symbol(nameof(t.args[1].op),
95+
,
96+
tosymbol(t.op.x))
97+
args = t.args[1].args
98+
else
99+
@goto err
100+
end
101+
102+
return escape ? Symbol(op, "", join(args, ", "), "") : op
103+
@label err
70104
error("Cannot convert $t to a symbol")
71105
end
72-
makesym(t::Sym{T}) where {T} = t
73-
makesym(t::Sym{FnType{T, S}}) where {T,S} = Sym{S}(nameof(t))
106+
107+
makesym(t::Symbolic; kwargs...) = Sym{symtype(t)}(tosymbol(t; kwargs...))
108+
makesym(t::Num; kwargs...) = makesym(value(t); kwargs...)
74109

75110
function generate_function(sys::AbstractODESystem, dvs = states(sys), ps = parameters(sys); kwargs...)
76111
# optimization
77-
dvs′ = makesym.(value.(dvs))
78-
ps′ = makesym.(value.(ps))
112+
dvs′ = makesym.(value.(dvs), states=dvs)
113+
ps′ = makesym.(value.(ps), states=dvs)
79114

80115
sub = Dict(dvs .=> dvs′)
81116
# substitute x(t) by just x
@@ -90,7 +125,7 @@ function calculate_massmatrix(sys::AbstractODESystem; simplify=true)
90125
M = zeros(length(eqs),length(eqs))
91126
for (i,eq) in enumerate(eqs)
92127
if eq.lhs isa Term && eq.lhs.op isa Differential
93-
j = findfirst(x->isequal(term_to_symbol(x),term_to_symbol(var_from_nested_derivative(eq.lhs)[1])),dvs)
128+
j = findfirst(x->isequal(tosymbol(x),tosymbol(var_from_nested_derivative(eq.lhs)[1])),dvs)
94129
M[i,j] = 1
95130
else
96131
eq.lhs == 0 || error("Only semi-explicit constant mass matrices are currently supported. Faulty equation: $eq.")

src/utils.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,8 @@ function states_to_sym(states::Set)
125125
Expr(:(=), _states_to_sym(O.lhs), _states_to_sym(O.rhs))
126126
elseif O isa Term
127127
if isa(O.op, Sym)
128-
O in states && return O.op.name # dependent variables
128+
O in states && return tosymbol(O)
129+
# dependent variables
129130
return build_expr(:call, Any[O.op.name; _states_to_sym.(O.args)])
130131
else
131132
return build_expr(:call, Any[O.op; _states_to_sym.(O.args)])

test/derivatives.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ t1 = ModelingToolkit.gradient(tmp, [x1, x2])
9797
@parameters t k
9898
@variables x(t)
9999
@derivatives D'~k
100-
@test ModelingToolkit.makesym(D(x).val).name === :xˍk
100+
@test ModelingToolkit.makesym(D(x).val).name === Symbol("xˍk⦗t⦘")
101101

102102
using ModelingToolkit
103103
@variables t x(t)

test/odesystem.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ eqs = [D(x) ~ -A*x,
145145
D(y) ~ A*x - B*_x]
146146
de = ODESystem(eqs)
147147
@test begin
148+
local f
148149
f = eval(generate_function(de, [x,y], [A,B,C])[2])
149150
du = [0.0,0.0]
150151
f(du, [1.0,2.0], [1,2,3], 0.0)

0 commit comments

Comments
 (0)