Skip to content

Commit 135821e

Browse files
authored
Merge pull request #720 from SciML/s/addmulpow
WIP: upgrade to SymbolicUtils w/ fast terms
2 parents 17d02bb + 8f430f9 commit 135821e

26 files changed

+281
-189
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ RuntimeGeneratedFunctions = "0.4, 0.5"
5252
SafeTestsets = "0.0.1"
5353
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0"
5454
StaticArrays = "0.10, 0.11, 0.12, 1.0"
55-
SymbolicUtils = "0.6"
55+
SymbolicUtils = "0.7"
5656
TreeViews = "0.3"
5757
UnPack = "0.1, 1.0"
5858
Unitful = "1.1"

src/ModelingToolkit.jl

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@ RuntimeGeneratedFunctions.init(@__MODULE__)
2323
using RecursiveArrayTools
2424

2525
import SymbolicUtils
26-
import SymbolicUtils: Term, Sym, to_symbolic, FnType, @rule, Rewriters, substitute, similarterm
26+
import SymbolicUtils: Term, Add, Mul, Pow, Sym, to_symbolic, FnType, @rule, Rewriters, substitute, similarterm
27+
28+
import SymbolicUtils.Rewriters: Chain, Postwalk, Prewalk, Fixpoint
2729

2830
using LinearAlgebra: LU, BlasInt
2931

@@ -72,13 +74,23 @@ Base.show(io::IO, n::Num) = show_numwrap[] ? print(io, :(Num($(value(n))))) : Ba
7274

7375
Base.promote_rule(::Type{<:Number}, ::Type{<:Num}) = Num
7476
Base.promote_rule(::Type{<:Symbolic{<:Number}}, ::Type{<:Num}) = Num
75-
Base.getproperty(t::Term, f::Symbol) = f === :op ? operation(t) : f === :args ? arguments(t) : getfield(t, f)
77+
function Base.getproperty(t::Union{Add, Mul, Pow, Term}, f::Symbol)
78+
if f === :op
79+
Base.depwarn("`x.op` is deprecated, use `operation(x)` instead", :getproperty, force=true)
80+
operation(t)
81+
elseif f === :args
82+
Base.depwarn("`x.args` is deprecated, use `arguments(x)` instead", :getproperty, force=true)
83+
arguments(t)
84+
else
85+
getfield(t, f)
86+
end
87+
end
7688
<(s::Num, x) = value(s) <value(x)
7789
<(s, x::Num) = value(s) <value(x)
7890
<(s::Num, x::Num) = value(s) <value(x)
7991

8092
for T in (Integer, Rational)
81-
@eval Base.:(^)(n::Num, i::$T) = Num(Term{symtype(n)}(^, [value(n),i]))
93+
@eval Base.:(^)(n::Num, i::$T) = Num(value(n)^i)
8294
end
8395

8496
macro num_method(f, expr, Ts=nothing)

src/build_function.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -476,28 +476,28 @@ end
476476

477477
get_varnumber(varop, vars::Vector) = findfirst(x->isequal(x,varop),vars)
478478

479-
function numbered_expr(O::Union{Term,Sym},args...;varordering = args[1],offset = 0,
479+
function numbered_expr(O::Symbolic,args...;varordering = args[1],offset = 0,
480480
lhsname=gensym("du"),rhsnames=[gensym("MTK") for i in 1:length(args)])
481481
O = value(O)
482-
if O isa Sym || isa(O.op, Sym)
482+
if O isa Sym || isa(operation(O), Sym)
483483
for j in 1:length(args)
484484
i = get_varnumber(O,args[j])
485485
if i !== nothing
486486
return :($(rhsnames[j])[$(i+offset)])
487487
end
488488
end
489489
end
490-
return Expr(:call, O isa Sym ? tosymbol(O, escape=false) : Symbol(O.op),
490+
return Expr(:call, O isa Sym ? tosymbol(O, escape=false) : Symbol(operation(O)),
491491
[numbered_expr(x,args...;offset=offset,lhsname=lhsname,
492-
rhsnames=rhsnames,varordering=varordering) for x in O.args]...)
492+
rhsnames=rhsnames,varordering=varordering) for x in arguments(O)]...)
493493
end
494494

495495
function numbered_expr(de::ModelingToolkit.Equation,args...;varordering = args[1],
496496
lhsname=gensym("du"),rhsnames=[gensym("MTK") for i in 1:length(args)],offset=0)
497497

498498
varordering = value.(args[1])
499499
var = var_from_nested_derivative(de.lhs)[1]
500-
i = findfirst(x->isequal(tosymbol(x isa Sym ? x : x.op, escape=false), tosymbol(var, escape=false)),varordering)
500+
i = findfirst(x->isequal(tosymbol(x isa Sym ? x : operation(x), escape=false), tosymbol(var, escape=false)),varordering)
501501
:($lhsname[$(i+offset)] = $(numbered_expr(de.rhs,args...;offset=offset,
502502
varordering = varordering,
503503
lhsname = lhsname,

src/context_dsl.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
1-
import SymbolicUtils: symtype
1+
import SymbolicUtils: symtype, term
22
struct Parameter{T} end
33

44
isparameter(x) = false
55
isparameter(::Sym{<:Parameter}) = true
66
isparameter(::Sym{<:FnType{<:Any, <:Parameter}}) = true
77

8+
SymbolicUtils.@number_methods(Sym{Parameter{Real}},
9+
term(f, a),
10+
term(f, a, b), skipbasics)
11+
812
SymbolicUtils.symtype(s::Symbolic{Parameter{T}}) where T = T
913
SymbolicUtils.similarterm(t::Term{T}, f, args) where {T<:Parameter} = Term{T}(f, args)
1014

src/differentials.jl

Lines changed: 36 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,14 @@ Base.show(io::IO, D::Differential) = print(io, "(D'~", D.x, ")")
3535
Base.:(==)(D1::Differential, D2::Differential) = isequal(D1.x, D2.x)
3636

3737
_isfalse(occ::Bool) = occ === false
38-
_isfalse(occ::Term) = _isfalse(occ.op)
38+
_isfalse(occ::Term) = _isfalse(operation(occ))
3939

40-
function occursin_info(x, expr::Term)
40+
function occursin_info(x, expr)
41+
!istree(expr) && return false
4142
if isequal(x, expr)
4243
true
4344
else
44-
args = map(a->occursin_info(x, a), expr.args)
45+
args = map(a->occursin_info(x, a), arguments(expr))
4546
if all(_isfalse, args)
4647
return false
4748
end
@@ -52,66 +53,65 @@ function occursin_info(x, expr::Sym)
5253
isequal(x, expr)
5354
end
5455

55-
hasderiv(O::Term) = O.op isa Differential || any(hasderiv, O.args)
56-
hasderiv(O) = false
57-
58-
occursin_info(x, y) = false
56+
function hasderiv(O)
57+
istree(O) ? operation(O) isa Differential || any(hasderiv, arguments(O)) : false
58+
end
5959
"""
6060
$(SIGNATURES)
6161
6262
TODO
6363
"""
64-
function expand_derivatives(O::Term, simplify=true; occurances=nothing)
65-
if isa(O.op, Differential)
66-
@assert length(O.args) == 1
67-
arg = expand_derivatives(O.args[1], false)
64+
function expand_derivatives(O::Symbolic, simplify=false; occurances=nothing)
65+
if istree(O) && isa(operation(O), Differential)
66+
@assert length(arguments(O)) == 1
67+
arg = expand_derivatives(arguments(O)[1], false)
6868

6969
if occurances == nothing
70-
occurances = occursin_info(O.op.x, arg)
70+
occurances = occursin_info(operation(O).x, arg)
7171
end
7272

7373
_isfalse(occurances) && return 0
7474
occurances isa Bool && return 1 # means it's a `true`
7575

76-
(D, o) = (O.op, arg)
76+
D = operation(O)
7777

78-
if !isa(o, Term)
79-
return O # Cannot expand
80-
elseif isa(o.op, Sym)
81-
return O # Cannot expand
82-
elseif isa(o.op, Differential)
78+
if !istree(arg)
79+
return D(arg) # Cannot expand
80+
elseif isa(operation(arg), Sym)
81+
return D(arg) # Cannot expand
82+
elseif isa(operation(arg), Differential)
8383
# The recursive expand_derivatives was not able to remove
8484
# a nested Differential. We can attempt to differentiate the
8585
# inner expression wrt to the outer iv. And leave the
8686
# unexpandable Differential outside.
87-
if isequal(o.op.x, D.x)
88-
return O
87+
if isequal(operation(arg).x, D.x)
88+
return D(arg)
8989
else
90-
inner = expand_derivatives(D(o.args[1]), false)
90+
inner = expand_derivatives(D(arguments(arg)[1]), false)
9191
# if the inner expression is not expandable either, return
92-
if inner isa Term && operation(inner) isa Differential
93-
return O
92+
if istree(inner) && operation(inner) isa Differential
93+
return D(arg)
9494
else
95-
return expand_derivatives(o.op(inner), simplify)
95+
return expand_derivatives(operation(arg)(inner), simplify)
9696
end
9797
end
9898
end
9999

100-
l = length(o.args)
100+
l = length(arguments(arg))
101101
exprs = []
102102
c = 0
103103

104104
for i in 1:l
105-
t2 = expand_derivatives(D(o.args[i]),false, occurances=occurances.args[i])
105+
t2 = expand_derivatives(D(arguments(arg)[i]),false, occurances=arguments(occurances)[i])
106106

107107
x = if _iszero(t2)
108108
t2
109109
elseif _isone(t2)
110-
d = derivative_idx(o, i)
111-
d isa NoDeriv ? D(o) : d
110+
d = derivative_idx(arg, i)
111+
d isa NoDeriv ? D(arg) : d
112112
else
113-
t1 = derivative_idx(o, i)
114-
t1 = t1 isa NoDeriv ? D(o) : t1
113+
t1 = derivative_idx(arg, i)
114+
t1 = t1 isa NoDeriv ? D(arg) : t1
115115
make_operation(*, [t1, t2])
116116
end
117117

@@ -136,8 +136,8 @@ function expand_derivatives(O::Term, simplify=true; occurances=nothing)
136136
elseif !hasderiv(O)
137137
return O
138138
else
139-
args = map(a->expand_derivatives(a, false), O.args)
140-
O1 = make_operation(O.op, args)
139+
args = map(a->expand_derivatives(a, false), arguments(O))
140+
O1 = make_operation(operation(O), args)
141141
return simplify ? SymbolicUtils.simplify(O1) : O1
142142
end
143143
end
@@ -176,7 +176,7 @@ chain rule is not applied:
176176
julia> myop = sin(x) * y^2
177177
sin(x()) * y() ^ 2
178178
179-
julia> typeof(myop.op) # Op is multiplication function
179+
julia> typeof(operation(myop)) # Op is multiplication function
180180
typeof(*)
181181
182182
julia> ModelingToolkit.derivative_idx(myop, 1) # wrt. sin(x)
@@ -187,7 +187,9 @@ sin(x())
187187
```
188188
"""
189189
derivative_idx(O::Any, ::Any) = 0
190-
derivative_idx(O::Term, idx) = derivative(O.op, (O.args...,), Val(idx))
190+
function derivative_idx(O::Symbolic, idx)
191+
istree(O) ? derivative(operation(O), (arguments(O)...,), Val(idx)) : 0
192+
end
191193

192194
# Indicate that no derivative is defined.
193195
struct NoDeriv

src/direct.jl

Lines changed: 60 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,8 @@ let
128128

129129
_scalar = one(TermCombination)
130130

131-
linearity_propagator = [
131+
simterm(t, f, args) = Term{Any}(f, args)
132+
linearity_rules = [
132133
@rule +(~~xs) => reduce(+, filter(isidx, ~~xs), init=_scalar)
133134
@rule *(~~xs) => reduce(*, filter(isidx, ~~xs), init=_scalar)
134135
@rule (~f)(~x::(!isidx)) => _scalar
@@ -146,7 +147,8 @@ let
146147
else
147148
error("Function of unknown linearity used: ", ~f)
148149
end
149-
end] |> Rewriters.Chain |> Rewriters.Postwalk |> Rewriters.Fixpoint
150+
end]
151+
linearity_propagator = Fixpoint(Postwalk(Chain(linearity_rules); similarterm=simterm))
150152

151153
global hessian_sparsity
152154

@@ -164,9 +166,9 @@ let
164166
u = map(value, u)
165167
idx(i) = TermCombination(Set([Dict(i=>1)]))
166168
dict = Dict(u .=> idx.(1:length(u)))
167-
found = []
168-
f = Rewriters.Prewalk(x->haskey(dict, x) ? dict[x] : x)(f)
169-
_sparse(linearity_propagator(f), length(u))
169+
f = Rewriters.Prewalk(x->haskey(dict, x) ? dict[x] : x; similarterm=simterm)(f)
170+
lp = linearity_propagator(f)
171+
_sparse(lp, length(u))
170172
end
171173
end
172174

@@ -213,28 +215,60 @@ function sparsehessian(O, vars::AbstractVector; simplify = true)
213215
return H
214216
end
215217

216-
function toexpr(O::Term)
217-
if isa(O.op, Differential)
218-
return :(derivative($(toexpr(O.args[1])),$(toexpr(O.op.x))))
219-
elseif isa(O.op, Sym)
220-
isempty(O.args) && return O.op.name
221-
return Expr(:call, toexpr(O.op), toexpr.(O.args)...)
222-
end
223-
if O.op === (^)
224-
if length(O.args) > 1 && O.args[2] isa Number && O.args[2] < 0
225-
return Expr(:call, ^, Expr(:call, inv, toexpr(O.args[1])), -(O.args[2]))
226-
end
227-
end
228-
return Expr(:call, O.op, toexpr.(O.args)...)
218+
"""
219+
toexpr(O::Union{Symbolics,Num,Equation,AbstractArray}; canonicalize=true) -> Expr
220+
221+
Convert `Symbolics` into `Expr`. If `canonicalize`, then we turn exprs like
222+
`x^(-n)` into `inv(x)^n` to avoid type error when evaluating.
223+
"""
224+
function toexpr(O; canonicalize=true)
225+
if canonicalize
226+
canonical, O = canonicalexpr(O)
227+
canonical && return O
228+
else
229+
!istree(O) && return O
230+
end
231+
232+
op = operation(O)
233+
args = arguments(O)
234+
if op isa Differential
235+
return :(derivative($(toexpr(args[1]; canonicalize=canonicalize)),$(toexpr(op.x; canonicalize=canonicalize))))
236+
elseif op isa Sym
237+
isempty(args) && return nameof(op)
238+
return Expr(:call, toexpr(op; canonicalize=canonicalize), toexpr(args; canonicalize=canonicalize)...)
239+
end
240+
return Expr(:call, op, toexpr(args; canonicalize=canonicalize)...)
241+
end
242+
toexpr(s::Sym; kw...) = nameof(s)
243+
244+
"""
245+
canonicalexpr(O) -> (canonical::Bool, expr)
246+
247+
Canonicalize `O`. Return `canonical` if `expr` is valid code to generate.
248+
"""
249+
function canonicalexpr(O)
250+
!istree(O) && return true, O
251+
op = operation(O)
252+
args = arguments(O)
253+
if op === (^)
254+
if length(args) == 2 && args[2] isa Number && args[2] < 0
255+
ex = toexpr(args[1])
256+
if args[2] == -1
257+
expr = Expr(:call, inv, ex)
258+
else
259+
expr = Expr(:call, ^, Expr(:call, inv, ex), -args[2])
260+
end
261+
return true, expr
262+
end
263+
end
264+
return false, O
229265
end
230-
toexpr(s::Sym) = nameof(s)
231-
toexpr(s) = s
232266

233-
function toexpr(eq::Equation)
234-
Expr(:(=), toexpr(eq.lhs), toexpr(eq.rhs))
267+
function toexpr(eq::Equation; kw...)
268+
Expr(:(=), toexpr(eq.lhs; kw...), toexpr(eq.rhs; kw...))
235269
end
236270

237-
toexpr(eq::AbstractArray) = toexpr.(eq)
238-
toexpr(x::Integer) = x
239-
toexpr(x::AbstractFloat) = x
240-
toexpr(x::Num) = toexpr(value(x))
271+
toexpr(eqs::AbstractArray; kw...) = map(eq->toexpr(eq; kw...), eqs)
272+
toexpr(x::Integer; kw...) = x
273+
toexpr(x::AbstractFloat; kw...) = x
274+
toexpr(x::Num; kw...) = toexpr(value(x); kw...)

src/latexify_recipes.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,13 @@ prettify_expr(expr::Expr) = Expr(expr.head, prettify_expr.(expr.args)...)
1111
# that latexify can deal with
1212

1313
rhs = getfield.(eqs, :rhs)
14-
rhs = prettify_expr.(toexpr.(rhs))
14+
rhs = prettify_expr.(toexpr(rhs; canonicalize=false))
1515
rhs = [postwalk(x -> x isa Expr && length(x.args) == 1 ? x.args[1] : x, eq) for eq in rhs]
1616
rhs = [postwalk(x -> x isa Expr && x.args[1] == :derivative && length(x.args[2].args) == 2 ? :($(Symbol(:d, x.args[2]))/($(Symbol(:d, x.args[2].args[2])))) : x, eq) for eq in rhs]
1717
rhs = [postwalk(x -> x isa Expr && x.args[1] == :derivative ? "\\frac{d\\left($(Latexify.latexraw(x.args[2]))\\right)}{d$(Latexify.latexraw(x.args[3]))}" : x, eq) for eq in rhs]
1818

1919
lhs = getfield.(eqs, :lhs)
20-
lhs = prettify_expr.(toexpr.(lhs))
20+
lhs = prettify_expr.(toexpr(lhs; canonicalize=false))
2121
lhs = [postwalk(x -> x isa Expr && length(x.args) == 1 ? x.args[1] : x, eq) for eq in lhs]
2222
lhs = [postwalk(x -> x isa Expr && x.args[1] == :derivative && length(x.args[2].args) == 2 ? :($(Symbol(:d, x.args[2]))/($(Symbol(:d, x.args[2].args[2])))) : x, eq) for eq in lhs]
2323
lhs = [postwalk(x -> x isa Expr && x.args[1] == :derivative ? "\\frac{d\\left($(Latexify.latexraw(x.args[2]))\\right)}{d$(Latexify.latexraw(x.args[3]))}" : x, eq) for eq in lhs]

src/linearity.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,14 @@ function Base.:(==)(comb1::TermCombination, comb2::TermCombination)
7878
end
7979
=#
8080

81+
# to make Mul and Add work
82+
Base.:*(::Number, comb::TermCombination) = comb
83+
function Base.:^(comb::TermCombination, ::Number)
84+
isone(comb) && return comb
85+
iszero(comb) && return _scalar
86+
return comb * comb
87+
end
88+
8189
function Base.:+(comb1::TermCombination, comb2::TermCombination)
8290
if isone(comb1) && !iszero(comb2)
8391
return comb2

0 commit comments

Comments
 (0)