Skip to content

Commit 8b4d923

Browse files
committed
move util function to src/utils.jl and handle high order diff to symbol
1 parent 5e290b6 commit 8b4d923

File tree

3 files changed

+65
-70
lines changed

3 files changed

+65
-70
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -68,45 +68,6 @@ function generate_jacobian(sys::AbstractODESystem, dvs = states(sys), ps = param
6868
conv = ODEToExpr(sys), kwargs...)
6969
end
7070

71-
Base.Symbol(x::Union{Num,Symbolic}) = tosymbol(x)
72-
tosymbol(x; kwargs...) = x
73-
tosymbol(x::Sym; kwargs...) = nameof(x)
74-
tosymbol(t::Num; kwargs...) = tosymbol(value(t); kwargs...)
75-
76-
"""
77-
tosymbol(x::Union{Num,Symbolic}; states=nothing, escape=true) -> Symbol
78-
79-
Convert `x` to a symbol. `states` are the states of a system, and `escape`
80-
means if the target has escapes like `val"y⦗t⦘"`. If `escape` then it will only
81-
output `y` instead of `y⦗t⦘`.
82-
"""
83-
function tosymbol(t::Term; states=nothing, escape=true)
84-
if t.op isa Sym
85-
if states !== nothing && !(any(isequal(t), states))
86-
return nameof(t.op)
87-
end
88-
op = nameof(t.op)
89-
args = t.args
90-
elseif t.op isa Differential
91-
if !(t.args[1].op isa Sym)
92-
@goto err
93-
end
94-
op = Symbol(nameof(t.args[1].op),
95-
,
96-
tosymbol(t.op.x))
97-
args = t.args[1].args
98-
else
99-
@goto err
100-
end
101-
102-
return escape ? Symbol(op, "", join(args, ", "), "") : op
103-
@label err
104-
error("Cannot convert $t to a symbol")
105-
end
106-
107-
makesym(t::Symbolic; kwargs...) = Sym{symtype(t)}(tosymbol(t; kwargs...))
108-
makesym(t::Num; kwargs...) = makesym(value(t); kwargs...)
109-
11071
function generate_function(sys::AbstractODESystem, dvs = states(sys), ps = parameters(sys); kwargs...)
11172
# optimization
11273
dvs′ = makesym.(value.(dvs), states=dvs)

src/systems/diffeqs/first_order_transform.jl

Lines changed: 1 addition & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,3 @@
1-
function lower_varname(var::Term, idv, order)
2-
order == 0 && return var
3-
name = Symbol(nameof(var.op), , string(idv)^order)
4-
#name = Symbol(var.name, :ˍ, string(idv.name)^order)
5-
return Sym{symtype(var.op)}(name)(var.args[1])
6-
end
7-
8-
function lower_varname(t::Term, iv)
9-
var, order = var_from_nested_derivative(t)
10-
lower_varname(var, iv, order)
11-
end
12-
lower_varname(t::Sym, iv) = t
13-
14-
function flatten_differential(O::Term)
15-
@assert is_derivative(O) "invalid differential: $O"
16-
is_derivative(O.args[1]) || return (O.args[1], O.op.x, 1)
17-
(x, t, order) = flatten_differential(O.args[1])
18-
isequal(t, O.op.x) || throw(ArgumentError("non-matching differentials on lhs: $t, $(O.op.x)"))
19-
return (x, t, order + 1)
20-
end
21-
221
"""
232
$(TYPEDSIGNATURES)
243
@@ -47,7 +26,7 @@ function ode_order_lowering(eqs, iv, states)
4726
# only save to the dict when we need to lower the order to save memory
4827
maxorder > get(var_order, var, 1) && (var_order[var] = maxorder)
4928
var′ = lower_varname(var, iv, maxorder - 1)
50-
rhs′ = rename_lower_order(eq.rhs)
29+
rhs′ = diff2symbol(eq.rhs)
5130
push!(diff_vars, var′)
5231
push!(diff_eqs, D(var′) ~ rhs′)
5332
end
@@ -68,12 +47,3 @@ function ode_order_lowering(eqs, iv, states)
6847
# we want to order the equations and variables to be `(diff, alge)`
6948
return (vcat(diff_eqs, alge_eqs), vcat(diff_vars, alge_vars))
7049
end
71-
72-
function rename_lower_order(O)
73-
isa(O, Term) || return O
74-
if is_derivative(O)
75-
(x, t, order) = flatten_differential(O)
76-
return lower_varname(x, t, order)
77-
end
78-
return Term(O.op, rename_lower_order.(O.args))
79-
end

src/utils.jl

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,3 +155,67 @@ Maps the variable to a variable (state).
155155
"""
156156
tovar(s::Sym{<:Parameter}) = Sym{symtype(s)}(s.name)
157157
tovar(s::Sym) = s
158+
159+
Base.Symbol(x::Union{Num,Symbolic}) = tosymbol(x)
160+
tosymbol(x; kwargs...) = x
161+
tosymbol(x::Sym; kwargs...) = nameof(x)
162+
tosymbol(t::Num; kwargs...) = tosymbol(value(t); kwargs...)
163+
164+
"""
165+
tosymbol(x::Union{Num,Symbolic}; states=nothing, escape=true) -> Symbol
166+
167+
Convert `x` to a symbol. `states` are the states of a system, and `escape`
168+
means if the target has escapes like `val"y⦗t⦘"`. If `escape` then it will only
169+
output `y` instead of `y⦗t⦘`.
170+
"""
171+
function tosymbol(t::Term; states=nothing, escape=true)
172+
if t.op isa Sym
173+
if states !== nothing && !(any(isequal(t), states))
174+
return nameof(t.op)
175+
end
176+
op = nameof(t.op)
177+
args = t.args
178+
elseif t.op isa Differential
179+
term = diff2symbol(t)
180+
op = Symbol(operation(term))
181+
args = arguments(term)
182+
else
183+
@goto err
184+
end
185+
186+
return escape ? Symbol(op, "", join(args, ", "), "") : op
187+
@label err
188+
error("Cannot convert $t to a symbol")
189+
end
190+
191+
makesym(t::Symbolic; kwargs...) = Sym{symtype(t)}(tosymbol(t; kwargs...))
192+
makesym(t::Num; kwargs...) = makesym(value(t); kwargs...)
193+
194+
function lower_varname(var::Term, idv, order)
195+
order == 0 && return var
196+
name = Symbol(nameof(var.op), , string(idv)^order)
197+
return Sym{symtype(var.op)}(name)(var.args[1])
198+
end
199+
200+
function lower_varname(t::Term, iv)
201+
var, order = var_from_nested_derivative(t)
202+
lower_varname(var, iv, order)
203+
end
204+
lower_varname(t::Sym, iv) = t
205+
206+
function flatten_differential(O::Term)
207+
@assert is_derivative(O) "invalid differential: $O"
208+
is_derivative(O.args[1]) || return (O.args[1], O.op.x, 1)
209+
(x, t, order) = flatten_differential(O.args[1])
210+
isequal(t, O.op.x) || throw(ArgumentError("non-matching differentials on lhs: $t, $(O.op.x)"))
211+
return (x, t, order + 1)
212+
end
213+
214+
function diff2symbol(O)
215+
isa(O, Term) || return O
216+
if is_derivative(O)
217+
(x, t, order) = flatten_differential(O)
218+
return lower_varname(x, t, order)
219+
end
220+
return Term(O.op, diff2symbol.(O.args))
221+
end

0 commit comments

Comments
 (0)