Skip to content

Commit 83e12a6

Browse files
authored
Merge pull request #118 from JuliaSymbolics/s/generic
use interface on the matched expression instead of requiring it to be a `Term`
2 parents b08fa34 + 701fd9b commit 83e12a6

File tree

12 files changed

+157
-153
lines changed

12 files changed

+157
-153
lines changed

docs/api.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,14 @@ using SymbolicUtils # hide
2424

2525
## Interfacing
2626

27-
{{doc to_symbolic to_symbolic fn}}
28-
2927
{{doc istree istree fn}}
3028

3129
{{doc operation operation fn}}
3230

3331
{{doc arguments arguments fn}}
3432

33+
{{doc similarterm similarterm fn}}
34+
3535
## Rewriters
3636

3737
{{doc @rule @rule macro}}

docs/interface.md

Lines changed: 12 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,7 @@ This may sound like a roundabout way of doing it, but it can be really fast. In
1717

1818
## Defining the interface
1919

20-
SymbolicUtils uses a function `to_symbolic` to convert aribtrarty types to it's own internal types. Our intention is for SymbolicUtils to be useful even for packages with their own custom symbolic types which
21-
differ from those offered by SymbolicUtils. To this end, SymbolicUtils provides an interface which if implemented, will enable automatic conversion of types to SymbolicUtils types.
22-
23-
* an `operation`, (i.e. function to apply)
24-
* `arguments` which the `operation` is applied to
25-
* `variable` types which are the atoms from which the expression tree is built
26-
* optionally, a type which should `typeof(operation(arguments...))` should return if it were to be run.
20+
SymbolicUtils matchers can match any Julia object that implements an interface to traverse it as a tree.
2721

2822
In particular, the following methods should be defined for an expression tree type `T` with symbol types `S` to work
2923
with SymbolicUtils.jl
@@ -48,20 +42,16 @@ Returns the arguments (a `Vector`) for an expression tree.
4842
Called only if `istree(x)` is `true`. Part of the API required
4943
for `simplify` to work. Other required methods are `operation` and `istree`
5044

51-
#### `to_symbolic(x::S)`
52-
Convert your variable type to a `SymbolicUtils.Sym`. Suppose you have
53-
```julia
54-
struct MySymbol
55-
s::Symbol
56-
end
57-
```
58-
which could represent any type symbolically, then you would define
59-
```julia
60-
SymbolicUtils.to_symbolic(s::MySymbol) = SymbolicUtils.Sym(s.s)
61-
```
45+
In addition, the methods for `Base.hash` and `Base.isequal` should also be implemented by the types for the purposes of substitution and equality matching respectively.
6246

6347
### Optional
6448

49+
#### `similarterm(t::MyType, f, args)`
50+
51+
Construct a new term with the operation `f` and arguments `args`, the term should be similar to `t` in type. if `t` is a `Term` object a new Term is created with the same symtype as `t`. If not, the result is computed as `f(args...)`. Defining this method for your term type will reduce any performance loss in performing `f(args...)` (esp. the splatting, and redundant type computation).
52+
53+
The below two functions are internal to SymbolicUtils
54+
6555
#### `symtype(x)`
6656

6757
The supposed type of values in the domain of x. Tracing tools can use this type to
@@ -102,38 +92,24 @@ How can we use SymbolicUtils.jl to convert `ex` to `(-)(:x, 1)`? We simply imple
10292
`operation`, `arguments` and `to_symbolic` and we'll be off to the races:
10393
```julia:piracy2
10494
using SymbolicUtils
105-
using SymbolicUtils: Sym, Term, istree, operation, arguments, to_symbolic
95+
using SymbolicUtils: istree, operation, arguments, similarterm
10696
10797
SymbolicUtils.istree(ex::Expr) = ex.head == :call
10898
SymbolicUtils.operation(ex::Expr) = ex.args[1]
10999
SymbolicUtils.arguments(ex::Expr) = ex.args[2:end]
110-
SymbolicUtils.to_symbolic(s::Symbol) = Sym(s)
111100
112101
@show simplify(ex)
113102
114103
dump(simplify(ex))
115104
```
116-
\out{piracy2}
117-
118-
this thing returns a `Term{Any}`, but it's not hard to convert back to `Expr`:
119-
120-
```julia:piracy3
121-
to_expr(t::Term) = Expr(:call, operation(t), to_expr.(arguments(t))...)
122-
to_expr(x) = x
123-
124-
@show expr = to_expr(simplify(ex))
125-
126-
dump(expr)
127-
```
128-
\out{piracy3}
129105

106+
There was no simplification, because by default SymbolicUtils assumes that the expressoins are of type Any and no particular rules apply. Let's change this by saying that the symbolic type (symtype) of an Expr or Symbol object is actually Real.
130107

131-
Now suppose we actaully wanted all `Symbol`s to be treated as `Real` numbers. We can simply define
132108
```julia:piracy4
133-
SymbolicUtils.symtype(s::Symbol) = Real
109+
SymbolicUtils.symtype(s::Expr) = Real
134110
135111
dump(simplify(ex))
136112
```
137113
\out{piracy4}
138114

139-
and now all our analysis is able to figure out that the `Term`s are `Number`s.
115+
Now SymbolicUtils is able to apply the Number simplification rule to Expr.

src/SymbolicUtils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module SymbolicUtils
22

3-
export @syms, term, @fun, showraw
3+
export @syms, term, showraw
44

55
# Sym, Term and other types
66
include("types.jl")
@@ -21,7 +21,7 @@ include("rewriters.jl")
2121
using .Rewriters
2222

2323
using Combinatorics: permutations, combinations
24-
export @rule, @acrule, @arule, RuleSet
24+
export @rule, @acrule, RuleSet
2525

2626
# Rule type and @rule macro
2727
include("rule.jl")

src/abstractalgebra.jl

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -21,22 +21,22 @@ end
2121
function labels!(dicts, t)
2222
if t isa Integer
2323
return t
24-
elseif t isa Term && (operation(t) == (*) || operation(t) == (+) || operation(t) == (-))
24+
elseif istree(t) && (operation(t) == (*) || operation(t) == (+) || operation(t) == (-))
2525
tt = arguments(t)
26-
return Term{symtype(t)}(operation(t), map(x->labels!(dicts, x), arguments(t)))
27-
elseif t isa Term && operation(t) == (^) && length(arguments(t)) > 1 && isnonnegint(arguments(t)[2])
28-
return Term{symtype(t)}(operation(t), map(x->labels!(dicts, x), arguments(t)))
26+
return similarterm(t, operation(t), map(x->labels!(dicts, x), arguments(t)))
27+
elseif istree(t) && operation(t) == (^) && length(arguments(t)) > 1 && isnonnegint(arguments(t)[2])
28+
return similarterm(t, operation(t), map(x->labels!(dicts, x), arguments(t)))
2929
else
3030
sym2term, term2sym = dicts
3131
if haskey(term2sym, t)
3232
return term2sym[t]
3333
end
34-
if t isa Term
34+
if istree(t)
3535
tt = arguments(t)
3636
sym = Sym{symtype(t)}(gensym(nameof(operation(t))))
3737
dicts2 = _dicts(dicts[2])
38-
sym2term[sym] = Term{symtype(t)}(operation(t),
39-
map(x->to_mpoly(x, dicts)[1], arguments(t)))
38+
sym2term[sym] = similarterm(t, operation(t),
39+
map(x->to_mpoly(x, dicts)[1], arguments(t)))
4040
else
4141
sym = Sym{symtype(t)}(gensym("literal"))
4242
sym2term[sym] = t
@@ -92,16 +92,16 @@ let
9292
end
9393
end
9494

95-
function to_term(x, dict)
95+
function to_term(reference, x, dict)
9696
syms = Dict(zip(nameof.(keys(dict)), keys(dict)))
9797
dict = copy(dict)
9898
for (k, v) in dict
99-
dict[k] = _to_term(v, dict, syms)
99+
dict[k] = _to_term(reference, v, dict, syms)
100100
end
101-
_to_term(x, dict, syms)
101+
_to_term(reference, x, dict, syms)
102102
end
103103

104-
function _to_term(x::MPoly, dict, syms)
104+
function _to_term(reference, x::MPoly, dict, syms)
105105

106106
function mul_coeffs(exps, ring)
107107
l = length(syms)
@@ -112,8 +112,7 @@ function _to_term(x::MPoly, dict, syms)
112112
elseif length(monics) == 0
113113
return 1
114114
else
115-
T = reduce((x,y)->promote_symtype(*, x,y), symtype.(monics))
116-
return Term{T}(*, monics)
115+
return similarterm(reference, *, monics)
117116
end
118117
end
119118

@@ -123,27 +122,29 @@ function _to_term(x::MPoly, dict, syms)
123122
elseif length(monoms) == 1
124123
t = !isone(x.coeffs[1]) ? monoms[1] * x.coeffs[1] : monoms[1]
125124
else
126-
T = reduce((x,y)->promote_symtype(+, x,y), symtype.(monoms))
127-
t = Term{T}(+, map((x,y)->isone(y) ? x : y*x, monoms, x.coeffs[1:length(monoms)]))
125+
t = similarterm(reference,
126+
+,
127+
map((x,y)->isone(y) ? x : y*x,
128+
monoms, x.coeffs[1:length(monoms)]))
128129
end
129130

130131
substitute(t, dict, fold=false)
131132
end
132133

133-
function _to_term(x, dict, vars)
134-
if haskey(dict, x)
135-
return dict[x]
134+
function _to_term(reference, x, dict, vars)
135+
if istree(x)
136+
t=similarterm(x, operation(x), _to_term.((reference,), arguments(x), (dict,), (vars,)))
136137
else
137-
return x
138+
if haskey(dict, x)
139+
return dict[x]
140+
else
141+
return x
142+
end
138143
end
139144
end
140145

141-
function _to_term(x::Term, dict, vars)
142-
t=Term{symtype(x)}(operation(x), _to_term.(arguments(x), (dict,), (vars,)))
143-
end
144-
145146
<(a::MPoly, b::MPoly) = false
146147

147148
function polynormalize(x)
148-
to_term(to_mpoly(x)...)
149+
to_term(x, to_mpoly(x)...)
149150
end

src/matchers.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@
77
#
88
function matcher(val::Any)
99
function literal_matcher(next, data, bindings)
10-
!isempty(data) && isequal(car(data), val) ? next(bindings, 1) : nothing
10+
islist(data) && isequal(car(data), val) ? next(bindings, 1) : nothing
1111
end
1212
end
1313

1414
function matcher(slot::Slot)
1515
function slot_matcher(next, data, bindings)
16-
isempty(data) && return
16+
!islist(data) && return
1717
val = get(bindings, slot.name, nothing)
1818
if val !== nothing
1919
if isequal(val, car(data))
@@ -29,10 +29,10 @@ end
2929

3030
# returns n == offset, 0 if failed
3131
function trymatchexpr(data, value, n)
32-
if isempty(value)
32+
if !islist(value)
3333
return n
3434
elseif islist(value) && islist(data)
35-
if isempty(data)
35+
if !islist(data)
3636
# didn't fully match
3737
return nothing
3838
end
@@ -42,14 +42,14 @@ function trymatchexpr(data, value, n)
4242
value = cdr(value)
4343
data = cdr(data)
4444

45-
if isempty(value)
45+
if !islist(value)
4646
return n
47-
elseif isempty(data)
47+
elseif !islist(data)
4848
return nothing
4949
end
5050
end
5151

52-
return isempty(value) ? n : nothing
52+
return !islist(value) ? n : nothing
5353
elseif isequal(value, data)
5454
return n + 1
5555
end
@@ -87,12 +87,12 @@ function matcher(term::Term)
8787
matchers = (matcher(operation(term)), map(matcher, arguments(term))...,)
8888
function term_matcher(success, data, bindings)
8989

90-
isempty(data) && return nothing
91-
!(car(data) isa Term) && return nothing
90+
!islist(data) && return nothing
91+
!(istree(car(data))) && return nothing
9292

9393
function loop(term, bindings′, matchers′) # Get it to compile faster
94-
if isempty(matchers′)
95-
if isempty(term)
94+
if !islist(matchers′)
95+
if !islist(term)
9696
return success(bindings′, 1)
9797
end
9898
return nothing

src/ordering.jl

Lines changed: 34 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -9,33 +9,43 @@
99
<(a::Number, b::Symbolic) = true
1010

1111
arglength(a) = length(arguments(a))
12-
function <(a::Sym, b::Term)
13-
args = arguments(b)
14-
if length(args) === 2
15-
n1, n2 = !isnumber(args[1]) , !isnumber(args[2])
16-
if n1 && n2
17-
# both subterms are terms, so it's definitely firster
18-
return true
19-
elseif n1
20-
return isequal(a, args[1]) || a <ₑ args[1]
21-
elseif n2
22-
return isequal(a, args[2]) || a <ₑ args[2]
12+
function <(a, b)
13+
if !istree(a) && !istree(b)
14+
T = typeof(a)
15+
S = typeof(b)
16+
T===S ? isless(a, b) : nameof(T) < nameof(S)
17+
elseif istree(b) && !istree(a)
18+
args = arguments(b)
19+
if length(args) === 2
20+
n1, n2 = !isnumber(args[1]) , !isnumber(args[2])
21+
if n1 && n2
22+
# both subterms are terms, so it's definitely firster
23+
return true
24+
elseif n1
25+
return isequal(a, args[1]) || a <ₑ args[1]
26+
elseif n2
27+
return isequal(a, args[2]) || a <ₑ args[2]
28+
else
29+
# both arguments are not numbers
30+
# This case when a <ₑ Term(^, [1,-1])
31+
# so this term should go to the left.
32+
return false
33+
end
34+
elseif length(args) === 1
35+
# make sure a < sin(a) < b^2 < b
36+
if isequal(a, args[1])
37+
return true # e.g sin(a)*a should become a*sin(a)
38+
else
39+
return a<ₑargs[1]
40+
end
2341
else
24-
# both arguments are not numbers
25-
# This case when a <ₑ Term(^, [1,-1])
26-
# so this term should go to the left.
42+
# variables to the right
2743
return false
2844
end
29-
elseif length(args) === 1
30-
# make sure a < sin(a) < b^2 < b
31-
if isequal(a, args[1])
32-
return true # e.g sin(a)*a should become a*sin(a)
33-
else
34-
return a<ₑargs[1]
35-
end
45+
elseif istree(a) && istree(b)
46+
cmp_term_term(a,b)
3647
else
37-
# variables to the right
38-
return false
48+
!(b <ₑ a)
3949
end
4050
end
4151

@@ -61,9 +71,8 @@ function <ₑ(a::Symbol, b::Symbol)
6171
end
6272

6373
<(a::Sym, b::Sym) = a.name < b.name
64-
<(a::T, b::S) where {T, S} = T===S ? isless(a, b) : nameof(T) < nameof(S)
6574

66-
function <(a::Term, b::Term)
75+
function cmp_term_term(a, b)
6776
la = arglength(a)
6877
lb = arglength(b)
6978

src/rewriters.jl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,7 @@ rewriters.
2626
2727
"""
2828
module Rewriters
29-
using SymbolicUtils: @timer, is_operation, istree, symtype, Term, operation, arguments,
30-
node_count
29+
using SymbolicUtils: @timer, is_operation, istree, operation, similarterm, arguments, node_count
3130

3231
export Empty, IfElse, If, Chain, RestartedChain, Fixpoint, Postwalk, Prewalk, PassThrough
3332

@@ -149,9 +148,7 @@ function (p::Walk{ord, C, false})(x) where {ord, C}
149148
x = p.rw(x)
150149
end
151150
if istree(x)
152-
x = Term{symtype(x)}(operation(x),
153-
map(t->PassThrough(p)(t),
154-
arguments(x)))
151+
x = similarterm(x, operation(x), map(PassThrough(p), arguments(x)))
155152
end
156153
return ord === :post ? p.rw(x) : x
157154
else
@@ -174,7 +171,7 @@ function (p::Walk{ord, C, true})(x) where {ord, C}
174171
end
175172
end
176173
args = map((t,a) -> passthrough(t isa Task ? fetch(t) : t, a), _args, arguments(x))
177-
t = Term{symtype(x)}(operation(x), args)
174+
t = similarterm(x, operation(x), args)
178175
end
179176
return ord === :post ? p.rw(t) : t
180177
else

0 commit comments

Comments
 (0)