Skip to content

Commit 7d6f362

Browse files
authored
Merge pull request #154 from JuliaSymbolics/s/fast-terms
WIP: constructor-level simplification
2 parents 2461789 + c62beb7 commit 7d6f362

File tree

11 files changed

+454
-93
lines changed

11 files changed

+454
-93
lines changed

docs/api.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,12 @@ using SymbolicUtils # hide
2020

2121
{{doc Term Term type}}
2222

23+
{{doc Add Add type}}
24+
25+
{{doc Mul Mul type}}
26+
27+
{{doc Pow Pow type}}
28+
2329
{{doc promote_symtype promote_symtype fn}}
2430

2531
## Interfacing

docs/index.md

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,20 @@ where appropriate -->
1212

1313
The main features are:
1414

15-
- Symbols (`Sym`s) carry type information. ([read more](#symbolic_expressions))
16-
- Compound expressions composed of `Sym`s propagate type information. ([read more](#symbolic_expressions))
17-
- A flexible [rule-based rewriting language](#rule-based_rewriting) allowing liberal use of user defined matchers and rewriters.
15+
- Fast expressions
1816
- A [combinator library](#composing-rewriters) for making rewriters.
17+
- A [rule-based rewriting language](#rule-based_rewriting).
18+
- Type promotion:
19+
- Symbols (`Sym`s) carry type information. ([read more](#symbolic_expressions))
20+
- Compound expressions composed of `Sym`s propagate type information. ([read more](#symbolic_expressions))
1921
- Set of [simplification rules](#simplification). These can be remixed and extended for special purposes.
2022

2123

2224
## Table of contents
2325

2426
\tableofcontents <!-- you can use \toc as well -->
2527

26-
## Symbolic expressions
28+
## `Sym`s
2729

2830
First, let's use the `@syms` macro to create a few symbols.
2931

@@ -66,17 +68,6 @@ expr1 + expr2
6668
```
6769
\out{expr}
6870

69-
### Simplified printing
70-
71-
Tip: you can set `SymbolicUtils.show_simplified[] = true` to enable simplification on printing, or call `SymbolicUtils.showraw(expr)` to display an expression without simplification.
72-
In the REPL, if an expression was successfully simplified before printing, it will appear in yellow rather than white, as a visual cue that what you are looking at is not the exact datastructure.
73-
74-
```julia:showraw
75-
using SymbolicUtils: showraw
76-
77-
showraw(expr1 + expr2)
78-
```
79-
\out{showraw}
8071

8172
**Function-like symbols**
8273

@@ -106,6 +97,20 @@ g(2//5, g(1, β))
10697

10798
This works because `g` "returns" a `Real`.
10899

100+
101+
## Expression interface
102+
103+
Symbolic expressions are of type `Term{T}`, `Add{T}`, `Mul{T}` or `Pow{T}` and denote some function call where one or more arguments are themselves such expressions or `Sym`s.
104+
105+
All the expression types support the following:
106+
107+
- `istree(x)` -- always returns `true` denoting, `x` is not a leaf node like Sym or a literal.
108+
- `operation(x)` -- the function being called
109+
- `arguments(x)` -- a vector of arguments
110+
- `symtype(x)` -- the "inferred" type (`T`)
111+
112+
See more on the interface [here](/interface)
113+
109114
## Rule-based rewriting
110115

111116
Rewrite rules match and transform an expression. A rule is written using either the `@rule` macro or the `@acrule` macro.
@@ -151,7 +156,7 @@ Notice that there is a subexpression `(2 * w) + (2 * w)` that could be simplifie
151156

152157
### Predicates for matching
153158

154-
Matcher pattern may contain slot variables with attached predicates, written as `~x::f` where `f` is a function that takes a matched expression (a `Term` object a `Sym` or any Julia value that is in the expression tree) and returns a boolean value. Such a slot will be considered a match only if `f` returns true.
159+
Matcher pattern may contain slot variables with attached predicates, written as `~x::f` where `f` is a function that takes a matched expression and returns a boolean value. Such a slot will be considered a match only if `f` returns true.
155160

156161
Similarly `~~x::g` is a way of attaching a predicate `g` to a segment variable. In the case of segment variables `g` gets a vector of 0 or more expressions and must return a boolean value. If the same slot or segment variable appears twice in the matcher pattern, then at most one of the occurance should have a predicate.
157162

src/SymbolicUtils.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@ module SymbolicUtils
22

33
export @syms, term, showraw
44

5-
# Sym, Term and other types
5+
# Sym, Term,
6+
# Add, Mul and Pow
7+
using DataStructures
8+
import Base: +, -, *, /, \, ^
69
include("types.jl")
710

811
# Methods on symbolic objects
@@ -32,7 +35,6 @@ include("matchers.jl")
3235
# Convert to an efficient multi-variate polynomial representation
3336
import AbstractAlgebra.Generic: MPoly, PolynomialRing, ZZ, exponent_vector
3437
using AbstractAlgebra: ismonomial, symbols
35-
using DataStructures
3638
include("abstractalgebra.jl")
3739

3840
# Term ordering

src/methods.jl

Lines changed: 26 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import SpecialFunctions: gamma, loggamma, erf, erfc, erfcinv, erfi, erfcx,
44
besselj1, bessely0, bessely1, besselj, bessely, besseli,
55
besselk, hankelh1, hankelh2, polygamma, beta, logbeta
66

7-
const monadic = [deg2rad, rad2deg, transpose, -, conj, asind, log1p, acsch,
7+
const monadic = [deg2rad, rad2deg, transpose, conj, asind, log1p, acsch,
88
acos, asec, acosh, acsc, cscd, log, tand, log10, csch, asinh,
99
abs2, cosh, sin, cos, atan, cospi, cbrt, acosd, acoth, acotd,
1010
asecd, exp, acot, sqrt, sind, sinpi, asech, log2, tan, exp10,
@@ -14,10 +14,9 @@ const monadic = [deg2rad, rad2deg, transpose, -, conj, asind, log1p, acsch,
1414
trigamma, invdigamma, polygamma, airyai, airyaiprime, airybi,
1515
airybiprime, besselj0, besselj1, bessely0, bessely1]
1616

17-
const diadic = [+, -, max, min, *, /, \, hypot, atan, mod, rem, ^, copysign,
17+
const diadic = [max, min, hypot, atan, mod, rem, copysign,
1818
besselj, bessely, besseli, besselk, hankelh1, hankelh2,
1919
polygamma, beta, logbeta]
20-
2120
const previously_declared_for = Set([])
2221

2322
# TODO: it's not possible to dispatch on the symtype! (only problem is Parameter{})
@@ -32,13 +31,17 @@ end
3231
islike(a, T) = symtype(a) <: T
3332

3433
# TODO: keep domains tighter than this
35-
function number_methods(T, rhs1, rhs2)
34+
function number_methods(T, rhs1, rhs2, options=nothing)
3635
exprs = []
3736

37+
skip_basics = !isnothing(options) ? options == :skipbasics : false
38+
basic_monadic = [-, +]
39+
basic_diadic = [+, -, *, /, \, ^]
40+
3841
rhs2 = :($assert_like(f, Number, a, b); $rhs2)
3942
rhs1 = :($assert_like(f, Number, a); $rhs1)
4043

41-
for f in diadic
44+
for f in (skip_basics ? diadic : vcat(basic_diadic, diadic))
4245
for S in previously_declared_for
4346
push!(exprs, quote
4447
(f::$(typeof(f)))(a::$T, b::$S) = $rhs2
@@ -58,25 +61,38 @@ function number_methods(T, rhs1, rhs2)
5861
push!(exprs, expr)
5962
end
6063

61-
for f in monadic
64+
for f in (skip_basics ? monadic : vcat(basic_monadic, monadic))
6265
push!(exprs, :((f::$(typeof(f)))(a::$T) = $rhs1))
6366
end
6467
push!(exprs, :(push!($previously_declared_for, $T)))
6568
Expr(:block, exprs...)
6669
end
6770

68-
macro number_methods(T, rhs1, rhs2)
69-
number_methods(T, rhs1, rhs2) |> esc
71+
macro number_methods(T, rhs1, rhs2, options=nothing)
72+
number_methods(T, rhs1, rhs2, options) |> esc
7073
end
7174

72-
@number_methods(Sym, term(f, a), term(f, a, b))
73-
@number_methods(Term, term(f, a), term(f, a, b))
75+
@number_methods(Sym, term(f, a), term(f, a, b), skipbasics)
76+
@number_methods(Term, term(f, a), term(f, a, b), skipbasics)
77+
@number_methods(Add, term(f, a), term(f, a, b), skipbasics)
78+
@number_methods(Mul, term(f, a), term(f, a, b), skipbasics)
79+
@number_methods(Pow, term(f, a), term(f, a, b), skipbasics)
7480

7581
for f in diadic
7682
@eval promote_symtype(::$(typeof(f)),
7783
T::Type{<:Number},
7884
S::Type{<:Number}) = promote_type(T, S)
7985
end
86+
87+
for f in [+, -, *, \, /, ^]
88+
@eval promote_symtype(::$(typeof(f)),
89+
T::Type{<:Number},
90+
S::Type{<:Number}) = promote_type(T, S)
91+
end
92+
for f in [+, -, *]
93+
@eval promote_symtype(::$(typeof(f)), T::Type{<:Number}) = T
94+
end
95+
8096
promote_symtype(::typeof(rem2pi), T::Type{<:Number}, mode) = T
8197
Base.rem2pi(x::Symbolic, mode::Base.RoundingMode) = term(rem2pi, x, mode)
8298

@@ -93,25 +109,6 @@ rec_promote_symtype(f, x) = promote_symtype(f, x)
93109
rec_promote_symtype(f, x,y) = promote_symtype(f, x,y)
94110
rec_promote_symtype(f, x,y,z...) = rec_promote_symtype(f, promote_symtype(f, x,y), z...)
95111

96-
# Variadic methods
97-
for f in [+, *]
98-
99-
@eval (::$(typeof(f)))(x::Symbolic) = x
100-
101-
# single arg
102-
@eval function (::$(typeof(f)))(x::Symbolic, w::Number...)
103-
term($f, x,w...,
104-
type=rec_promote_symtype($f, map(symtype, (x,w...))...))
105-
end
106-
@eval function (::$(typeof(f)))(x::Number, y::Symbolic, w::Number...)
107-
term($f, x, y, w...,
108-
type=rec_promote_symtype($f, map(symtype, (x, y, w...))...))
109-
end
110-
@eval function (::$(typeof(f)))(x::Symbolic, y::Symbolic, w::Number...)
111-
term($f, x, y, w...,
112-
type=rec_promote_symtype($f, map(symtype, (x, y, w...))...))
113-
end
114-
end
115112

116113
Base.:*(a::AbstractArray, b::Symbolic{<:Number}) = map(x->x*b, a)
117114
Base.:*(a::Symbolic{<:Number}, b::AbstractArray) = map(x->a*x, b)

src/rewriters.jl

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -107,19 +107,20 @@ function (rw::Fixpoint)(x)
107107
return x
108108
end
109109

110-
struct Walk{ord, C, threaded}
110+
struct Walk{ord, C, F, threaded}
111111
rw::C
112112
thread_cutoff::Int
113+
similarterm::F
113114
end
114115

115116
using .Threads
116117

117-
function Postwalk(rw; threaded::Bool=false, thread_cutoff=100)
118-
Walk{:post, typeof(rw), threaded}(rw, thread_cutoff)
118+
function Postwalk(rw; threaded::Bool=false, thread_cutoff=100, similarterm=similarterm)
119+
Walk{:post, typeof(rw), typeof(similarterm), threaded}(rw, thread_cutoff, similarterm)
119120
end
120121

121-
function Prewalk(rw; threaded::Bool=false, thread_cutoff=100)
122-
Walk{:pre, typeof(rw), threaded}(rw, thread_cutoff)
122+
function Prewalk(rw; threaded::Bool=false, thread_cutoff=100, similarterm=similarterm)
123+
Walk{:pre, typeof(rw), typeof(similarterm), threaded}(rw, thread_cutoff, similarterm)
123124
end
124125

125126
struct PassThrough{C}
@@ -128,22 +129,22 @@ end
128129
(p::PassThrough)(x) = (y=p.rw(x); isnothing(y) ? x : y)
129130

130131
passthrough(x, default) = isnothing(x) ? default : x
131-
function (p::Walk{ord, C, false})(x) where {ord, C}
132+
function (p::Walk{ord, C, F, false})(x) where {ord, C, F}
132133
@assert ord === :pre || ord === :post
133134
if istree(x)
134135
if ord === :pre
135136
x = p.rw(x)
136137
end
137138
if istree(x)
138-
x = similarterm(x, operation(x), map(PassThrough(p), arguments(x)))
139+
x = p.similarterm(x, operation(x), map(PassThrough(p), arguments(x)))
139140
end
140141
return ord === :post ? p.rw(x) : x
141142
else
142143
return p.rw(x)
143144
end
144145
end
145146

146-
function (p::Walk{ord, C, true})(x) where {ord, C}
147+
function (p::Walk{ord, C, F, true})(x) where {ord, C, F}
147148
@assert ord === :pre || ord === :post
148149
if istree(x)
149150
if ord === :pre
@@ -158,7 +159,7 @@ function (p::Walk{ord, C, true})(x) where {ord, C}
158159
end
159160
end
160161
args = map((t,a) -> passthrough(t isa Task ? fetch(t) : t, a), _args, arguments(x))
161-
t = similarterm(x, operation(x), args)
162+
t = p.similarterm(x, operation(x), args)
162163
end
163164
return ord === :post ? p.rw(t) : t
164165
else

src/rule.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ function (acr::ACRule)(term)
321321
itr = acr.sets(eachindex(args), acr.arity)
322322

323323
for inds in itr
324-
result = r(similarterm(term, f, @views args[inds]))
324+
result = r(Term{T}(f, @views args[inds]))
325325
if !isnothing(result)
326326
# Assumption: inds are unique
327327
length(args) == length(inds) && return result

0 commit comments

Comments
 (0)