Skip to content

Commit eef9b7a

Browse files
Merge pull request #625 from SciML/myb/diff
Restructure and make Symbol composable
2 parents 5e290b6 + a3a357f commit eef9b7a

File tree

4 files changed

+84
-70
lines changed

4 files changed

+84
-70
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -68,45 +68,6 @@ function generate_jacobian(sys::AbstractODESystem, dvs = states(sys), ps = param
6868
conv = ODEToExpr(sys), kwargs...)
6969
end
7070

71-
Base.Symbol(x::Union{Num,Symbolic}) = tosymbol(x)
72-
tosymbol(x; kwargs...) = x
73-
tosymbol(x::Sym; kwargs...) = nameof(x)
74-
tosymbol(t::Num; kwargs...) = tosymbol(value(t); kwargs...)
75-
76-
"""
77-
tosymbol(x::Union{Num,Symbolic}; states=nothing, escape=true) -> Symbol
78-
79-
Convert `x` to a symbol. `states` are the states of a system, and `escape`
80-
means if the target has escapes like `val"y⦗t⦘"`. If `escape` then it will only
81-
output `y` instead of `y⦗t⦘`.
82-
"""
83-
function tosymbol(t::Term; states=nothing, escape=true)
84-
if t.op isa Sym
85-
if states !== nothing && !(any(isequal(t), states))
86-
return nameof(t.op)
87-
end
88-
op = nameof(t.op)
89-
args = t.args
90-
elseif t.op isa Differential
91-
if !(t.args[1].op isa Sym)
92-
@goto err
93-
end
94-
op = Symbol(nameof(t.args[1].op),
95-
,
96-
tosymbol(t.op.x))
97-
args = t.args[1].args
98-
else
99-
@goto err
100-
end
101-
102-
return escape ? Symbol(op, "", join(args, ", "), "") : op
103-
@label err
104-
error("Cannot convert $t to a symbol")
105-
end
106-
107-
makesym(t::Symbolic; kwargs...) = Sym{symtype(t)}(tosymbol(t; kwargs...))
108-
makesym(t::Num; kwargs...) = makesym(value(t); kwargs...)
109-
11071
function generate_function(sys::AbstractODESystem, dvs = states(sys), ps = parameters(sys); kwargs...)
11172
# optimization
11273
dvs′ = makesym.(value.(dvs), states=dvs)

src/systems/diffeqs/first_order_transform.jl

Lines changed: 1 addition & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,3 @@
1-
function lower_varname(var::Term, idv, order)
2-
order == 0 && return var
3-
name = Symbol(nameof(var.op), , string(idv)^order)
4-
#name = Symbol(var.name, :ˍ, string(idv.name)^order)
5-
return Sym{symtype(var.op)}(name)(var.args[1])
6-
end
7-
8-
function lower_varname(t::Term, iv)
9-
var, order = var_from_nested_derivative(t)
10-
lower_varname(var, iv, order)
11-
end
12-
lower_varname(t::Sym, iv) = t
13-
14-
function flatten_differential(O::Term)
15-
@assert is_derivative(O) "invalid differential: $O"
16-
is_derivative(O.args[1]) || return (O.args[1], O.op.x, 1)
17-
(x, t, order) = flatten_differential(O.args[1])
18-
isequal(t, O.op.x) || throw(ArgumentError("non-matching differentials on lhs: $t, $(O.op.x)"))
19-
return (x, t, order + 1)
20-
end
21-
221
"""
232
$(TYPEDSIGNATURES)
243
@@ -47,7 +26,7 @@ function ode_order_lowering(eqs, iv, states)
4726
# only save to the dict when we need to lower the order to save memory
4827
maxorder > get(var_order, var, 1) && (var_order[var] = maxorder)
4928
var′ = lower_varname(var, iv, maxorder - 1)
50-
rhs′ = rename_lower_order(eq.rhs)
29+
rhs′ = diff2term(eq.rhs)
5130
push!(diff_vars, var′)
5231
push!(diff_eqs, D(var′) ~ rhs′)
5332
end
@@ -68,12 +47,3 @@ function ode_order_lowering(eqs, iv, states)
6847
# we want to order the equations and variables to be `(diff, alge)`
6948
return (vcat(diff_eqs, alge_eqs), vcat(diff_vars, alge_vars))
7049
end
71-
72-
function rename_lower_order(O)
73-
isa(O, Term) || return O
74-
if is_derivative(O)
75-
(x, t, order) = flatten_differential(O)
76-
return lower_varname(x, t, order)
77-
end
78-
return Term(O.op, rename_lower_order.(O.args))
79-
end

src/utils.jl

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,3 +155,82 @@ Maps the variable to a variable (state).
155155
"""
156156
tovar(s::Sym{<:Parameter}) = Sym{symtype(s)}(s.name)
157157
tovar(s::Sym) = s
158+
159+
Base.Symbol(x::Union{Num,Symbolic}) = tosymbol(x)
160+
tosymbol(x; kwargs...) = x
161+
tosymbol(x::Sym; kwargs...) = nameof(x)
162+
tosymbol(t::Num; kwargs...) = tosymbol(value(t); kwargs...)
163+
164+
"""
165+
tosymbol(x::Union{Num,Symbolic}; states=nothing, escape=true) -> Symbol
166+
167+
Convert `x` to a symbol. `states` are the states of a system, and `escape`
168+
means if the target has escapes like `val"y⦗t⦘"`. If `escape` then it will only
169+
output `y` instead of `y⦗t⦘`.
170+
"""
171+
function tosymbol(t::Term; states=nothing, escape=true)
172+
if t.op isa Sym
173+
if states !== nothing && !(any(isequal(t), states))
174+
return nameof(t.op)
175+
end
176+
op = nameof(t.op)
177+
args = t.args
178+
elseif t.op isa Differential
179+
term = diff2term(t)
180+
op = Symbol(operation(term))
181+
args = arguments(term)
182+
else
183+
@goto err
184+
end
185+
186+
return escape ? Symbol(op, "", join(args, ", "), "") : op
187+
@label err
188+
error("Cannot convert $t to a symbol")
189+
end
190+
191+
makesym(t::Symbolic; kwargs...) = Sym{symtype(t)}(tosymbol(t; kwargs...))
192+
makesym(t::Num; kwargs...) = makesym(value(t); kwargs...)
193+
194+
function lower_varname(var::Term, idv, order)
195+
order == 0 && return var
196+
name = string(nameof(var.op))
197+
underscore = 'ˍ'
198+
idx = findlast(underscore, name)
199+
append = string(idv)^order
200+
if idx === nothing
201+
newname = Symbol(name, underscore, append)
202+
else
203+
nidx = nextind(name, idx)
204+
newname = Symbol(name[1:idx], name[nidx:end], append)
205+
end
206+
return Sym{symtype(var.op)}(newname)(var.args[1])
207+
end
208+
209+
function lower_varname(t::Term, iv)
210+
var, order = var_from_nested_derivative(t)
211+
lower_varname(var, iv, order)
212+
end
213+
lower_varname(t::Sym, iv) = t
214+
215+
function flatten_differential(O::Term)
216+
@assert is_derivative(O) "invalid differential: $O"
217+
is_derivative(O.args[1]) || return (O.args[1], O.op.x, 1)
218+
(x, t, order) = flatten_differential(O.args[1])
219+
isequal(t, O.op.x) || throw(ArgumentError("non-matching differentials on lhs: $t, $(O.op.x)"))
220+
return (x, t, order + 1)
221+
end
222+
223+
"""
224+
diff2term(x::Term) -> Term
225+
diff2term(x) -> x
226+
227+
diff2term(D(D(x(t)))) -> xˍtt(t)
228+
"""
229+
function diff2term(O)
230+
isa(O, Term) || return O
231+
if is_derivative(O)
232+
(x, t, order) = flatten_differential(O)
233+
return lower_varname(x, t, order)
234+
end
235+
return Term(O.op, diff2term.(O.args))
236+
end

test/derivatives.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,12 @@ using Test
44
# Derivatives
55
@parameters t σ ρ β
66
@variables x y z
7+
@variables uu(t) uuˍt(t)
78
@derivatives D'~t D2''~t Dx'~x
89

10+
@test Symbol(D(D(uu))) === Symbol("uuˍtt⦗t⦘")
11+
@test Symbol(D(uuˍt)) === Symbol(D(D(uu)))
12+
913
test_equal(a, b) = @test isequal(simplify(a), simplify(b))
1014

1115
@test @macroexpand(@derivatives D'~t D2''~t) == @macroexpand(@derivatives (D'~t), (D2''~t))

0 commit comments

Comments
 (0)