Skip to content

Commit fa6ed8d

Browse files
Merge pull request #752 from Bumblebee00/commutative_operations
[WIP] Commutative operations and negative exponent match in rules
2 parents 3dc5e53 + 8612b2d commit fa6ed8d

File tree

5 files changed

+383
-129
lines changed

5 files changed

+383
-129
lines changed

docs/src/manual/rewrite.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,13 +180,13 @@ It works. This can be further simplified using Pythagorean identity and check it
180180
```jldoctest rewriteex
181181
pyid = @rule sin(~x)^2 + cos(~x)^2 => 1
182182
183-
pyid(cos(x)^2 + sin(x)^2) === nothing
183+
pyid(sin(x)^2 + 2sin(x)*cos(x) + cos(x)^2)===nothing
184184
185185
# output
186186
true
187187
```
188188

189-
Why does it return `nothing`? If we look at the rule, we see that the order of `sin(x)` and `cos(x)` is different. Therefore, in order to work, the rule needs to be associative-commutative.
189+
Why does it return `nothing`? If we look at the expression, we see that we have an additional addend `+ 2sin(x)*cos(x)`. Therefore, in order to work, the rule needs to be associative-commutative.
190190

191191
```jldoctest rewriteex
192192
acpyid = @acrule sin(~x)^2 + cos(~x)^2 => 1

src/matchers.jl

Lines changed: 183 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,16 @@
66
# 3. Callback: takes arguments Dictionary × Number of elements matched
77
#
88

9-
function matcher(val::Any)
9+
function matcher(val::Any, acSets)
1010
# if val is a call (like an operation) creates a term matcher or term matcher with defslot
1111
if iscall(val)
1212
# if has two arguments and one of them is a DefSlot, create a term matcher with defslot
13-
if length(arguments(val)) == 2 && any(x -> isa(x, DefSlot), arguments(val))
14-
return defslot_term_matcher_constructor(val)
15-
# else return a normal term matcher
16-
else
17-
return term_matcher_constructor(val)
13+
# just two arguments bc defslot is only supported with operations with two args: *, ^, +
14+
if any(x -> isa(x, DefSlot), arguments(val))
15+
return defslot_term_matcher_constructor(val, acSets)
1816
end
17+
# else return a normal term matcher
18+
return term_matcher_constructor(val, acSets)
1919
end
2020

2121
function literal_matcher(next, data, bindings)
@@ -24,7 +24,8 @@ function matcher(val::Any)
2424
end
2525
end
2626

27-
function matcher(slot::Slot)
27+
# acSets is not used but needs to be there in case matcher(::Slot) is directly called from the macro
28+
function matcher(slot::Slot, acSets)
2829
function slot_matcher(next, data, bindings)
2930
!islist(data) && return nothing
3031
val = get(bindings, slot.name, nothing)
@@ -43,8 +44,8 @@ end
4344
# this is called only when defslot_term_matcher finds the operation and tries
4445
# to match it, so no default value used. So the same function as slot_matcher
4546
# can be used
46-
function matcher(defslot::DefSlot)
47-
matcher(Slot(defslot.name, defslot.predicate))
47+
function matcher(defslot::DefSlot, acSets)
48+
matcher(Slot(defslot.name, defslot.predicate), nothing) # slot matcher doesnt use acsets
4849
end
4950

5051
# returns n == offset, 0 if failed
@@ -75,7 +76,7 @@ function trymatchexpr(data, value, n)
7576
end
7677
end
7778

78-
function matcher(segment::Segment)
79+
function matcher(segment::Segment, acSets)
7980
function segment_matcher(success, data, bindings)
8081
val = get(bindings, segment.name, nothing)
8182

@@ -90,98 +91,202 @@ function matcher(segment::Segment)
9091
for i=length(data):-1:0
9192
subexpr = take_n(data, i)
9293

93-
if segment.predicate(subexpr)
94-
res = success(assoc(bindings, segment.name, subexpr), i)
95-
if res !== nothing
96-
break
97-
end
98-
end
94+
!segment.predicate(subexpr) && continue
95+
res = success(assoc(bindings, segment.name, subexpr), i)
96+
res !== nothing && break
9997
end
10098

10199
return res
102100
end
103101
end
104102
end
105103

106-
function term_matcher_constructor(term)
107-
matchers = (matcher(operation(term)), map(matcher, arguments(term))...,)
104+
function term_matcher_constructor(term, acSets)
105+
matchers = (matcher(operation(term), acSets), map(x->matcher(x,acSets), arguments(term))...,)
106+
107+
function loop(term, bindings′, matchers′) # Get it to compile faster
108+
if !islist(matchers′)
109+
if !islist(term)
110+
return bindings′
111+
end
112+
return nothing
113+
end
114+
car(matchers′)(term, bindings′) do b, n
115+
loop(drop_n(term, n), b, cdr(matchers′))
116+
end
117+
# explanation of above 3 lines:
118+
# car(matchers′)(b,n -> loop(drop_n(term, n), b, cdr(matchers′)), term, bindings′)
119+
# <------ next(b,n) ---------------------------->
120+
# car = first element of list, cdr = rest of the list, drop_n = drop first n elements of list
121+
# Calls the first matcher, with the "next" function being loop again but with n terms dropepd from term
122+
# Term is a linked list (a list and a index). drop n advances the index. when the index sorpasses
123+
# the length of the list, is considered empty
124+
end
108125

109-
function term_matcher(success, data, bindings)
110-
!islist(data) && return nothing # if data is not a list, return nothing
111-
!iscall(car(data)) && return nothing # if first element is not a call, return nothing
126+
# if the operation is a pow, we have to match also 1/(...)^(...) with negative exponent
127+
if operation(term) === ^
128+
function pow_term_matcher(success, data, bindings)
129+
!islist(data) && return nothing # if data is not a list, return nothing
130+
data = car(data) # from (..., ) to ...
131+
!iscall(data) && return nothing # if first element is not a call, return nothing
132+
133+
result = loop(data, bindings, matchers)
134+
result !== nothing && return success(result, 1)
135+
136+
frankestein = nothing
137+
if (operation(data) === ^) && iscall(arguments(data)[1]) && (operation(arguments(data)[1]) === /) && isequal(arguments(arguments(data)[1])[1], 1)
138+
# if data is of the alternative form (1/...)^(...)
139+
one_over_smth = arguments(data)[1]
140+
T = symtype(one_over_smth)
141+
frankestein = Term{T}(^, [arguments(one_over_smth)[2], -arguments(data)[2]])
142+
elseif (operation(data) === /) && isequal(arguments(data)[1], 1) && iscall(arguments(data)[2]) && (operation(arguments(data)[2]) === ^)
143+
# if data is of the alternative form 1/(...)^(...)
144+
denominator = arguments(data)[2]
145+
T = symtype(denominator)
146+
frankestein = Term{T}(^, [arguments(denominator)[1], -arguments(denominator)[2]])
147+
elseif (operation(data) === /) && isequal(arguments(data)[1], 1)
148+
# if data is of the alternative form 1/(...), it might match with exponent = -1
149+
denominator = arguments(data)[2]
150+
T = symtype(denominator)
151+
frankestein = Term{T}(^, [denominator, -1])
152+
elseif operation(data)===exp
153+
# if data is a exp call, it might match with base e
154+
T = symtype(arguments(data)[1])
155+
frankestein = Term{T}(^,[ℯ, arguments(data)[1]])
156+
elseif operation(data)===sqrt
157+
# if data is a sqrt call, it might match with exponent 1//2
158+
T = symtype(arguments(data)[1])
159+
frankestein = Term{T}(^,[arguments(data)[1], 1//2])
160+
end
161+
162+
if frankestein !==nothing
163+
result = loop(frankestein, bindings, matchers)
164+
result !== nothing && return success(result, 1)
165+
end
166+
167+
return nothing
168+
end
169+
return pow_term_matcher
170+
# if we want to do commutative checks, i.e. call matcher with different order of the arguments
171+
elseif acSets!==nothing && operation(term) in [+, *]
172+
function commutative_term_matcher(success, data, bindings)
173+
!islist(data) && return nothing # if data is not a list, return nothing
174+
!iscall(car(data)) && return nothing # if first element is not a call, return nothing
175+
operation(term) !== operation(car(data)) && return nothing # if the operation of data is not the correct one, don't even try
176+
177+
T = symtype(car(data))
178+
if T <: Number
179+
f = operation(car(data))
180+
data_args = arguments(car(data))
181+
182+
for inds in acSets(eachindex(data_args), length(data_args))
183+
candidate = Term{T}(f, @views data_args[inds])
112184

113-
function loop(term, bindings′, matchers′) # Get it to compile faster
114-
if !islist(matchers′)
115-
if !islist(term)
116-
return success(bindings′, 1)
185+
result = loop(candidate, bindings, matchers)
186+
result !== nothing && return success(result,1)
117187
end
118-
return nothing
188+
# if car(data) does not subtype to number, it might not be commutative
189+
else
190+
# call the normal matcher
191+
result = loop(car(data), bindings, matchers)
192+
result !== nothing && return success(result, 1)
119193
end
120-
car(matchers′)(term, bindings′) do b, n
121-
loop(drop_n(term, n), b, cdr(matchers′))
194+
return nothing
195+
end
196+
return commutative_term_matcher
197+
# if the operation is sqrt, we have to match also ^(1//2)
198+
elseif operation(term)==sqrt
199+
function sqrt_matcher(success, data, bindings)
200+
!islist(data) && return nothing # if data is not a list, return nothing
201+
data = car(data)
202+
!iscall(data) && return nothing # if first element is not a call, return nothing
203+
204+
# do the normal matcher
205+
result = loop(data, bindings, matchers)
206+
result !== nothing && return success(result, 1)
207+
208+
if (operation(data) === ^) && (arguments(data)[2] === 1//2)
209+
T = symtype(arguments(data)[1])
210+
frankestein = Term{T}(sqrt,[arguments(data)[1]])
211+
result = loop(frankestein, bindings, matchers)
212+
result !== nothing && return success(result, 1)
122213
end
123-
# explanation of above 3 lines:
124-
# car(matchers′)(b,n -> loop(drop_n(term, n), b, cdr(matchers′)), term, bindings′)
125-
# <------ next(b,n) ---------------------------->
126-
# car = first element of list, cdr = rest of the list, drop_n = drop first n elements of list
127-
# Calls the first matcher, with the "next" function being loop again but with n terms dropepd from term
128-
# Term is a linked list (a list and a index). drop n advances the index. when the index sorpasses
129-
# the length of the list, is considered empty
214+
return nothing
130215
end
216+
return sqrt_matcher
217+
# if the operation is exp, we have to match also ℯ^
218+
elseif operation(term)==exp
219+
function exp_matcher(success, data, bindings)
220+
!islist(data) && return nothing # if data is not a list, return nothing
221+
data = car(data)
222+
!iscall(data) && return nothing # if first element is not a call, return nothing
223+
224+
# do the normal matcher
225+
result = loop(data, bindings, matchers)
226+
result !== nothing && return success(result, 1)
131227

132-
loop(car(data), bindings, matchers) # Try to eat exactly one term
228+
if (operation(data) === ^) && (arguments(data)[1] === ℯ)
229+
T = symtype(arguments(data)[2])
230+
frankestein = Term{T}(exp,[arguments(data)[2]])
231+
result = loop(frankestein, bindings, matchers)
232+
result !== nothing && return success(result, 1)
233+
end
234+
return nothing
235+
end
236+
return exp_matcher
237+
else
238+
function term_matcher(success, data, bindings)
239+
!islist(data) && return nothing # if data is not a list, return nothing
240+
!iscall(car(data)) && return nothing # if first element is not a call, return nothing
241+
242+
result = loop(car(data), bindings, matchers)
243+
result !== nothing && return success(result, 1)
244+
return nothing
245+
end
246+
return term_matcher
133247
end
134248
end
135249

136250
# creates a matcher for a term containing a defslot, such as:
137251
# (~x + ...complicated pattern...) * ~!y
138252
# normal part (can bee a tree) operation defslot part
139253

140-
# defslot_term_matcher works like this:
141-
# checks whether data starts with the default operation.
142-
# if yes (1): continues like term_matcher
143-
# if no checks whether data matches the normal part
144-
# if no returns nothing, rule is not applied
145-
# if yes (2): adds the pair (default value name, default value) to the found bindings and
146-
# calls the success function like term_matcher would do
147-
148-
function defslot_term_matcher_constructor(term)
149-
a = arguments(term) # length two bc defslot term matcher is allowed only with +,* and ^, that accept two arguments
150-
matchers = (matcher(operation(term)), map(matcher, a)...) # create matchers for the operation and the two arguments of the term
151-
254+
# Note: there is a bit of a waste here bc the matcher get created twice, both
255+
# in the normal_matcher and in defslot_matcher and other_part_matcher
256+
function defslot_term_matcher_constructor(term, acSets)
257+
a = arguments(term)
152258
defslot_index = findfirst(x -> isa(x, DefSlot), a) # find the defslot in the term
153259
defslot = a[defslot_index]
260+
defslot_matcher = matcher(defslot, acSets)
261+
if length(a) == 2
262+
other_part_matcher = matcher(a[defslot_index == 1 ? 2 : 1], acSets)
263+
else
264+
# if we hare here the operation is a multiplication or sum of n>2 terms
265+
# (because ^ cannot have more than 2 terms).
266+
# creates the term matcher of the multiplication or sum of n-1 terms
267+
others = [a[i] for i in eachindex(a) if i != defslot_index]
268+
T = symtype(term)
269+
f = operation(term)
270+
other_part_matcher = term_matcher_constructor(Term{T}(f, others), acSets)
271+
end
154272

155-
function defslot_term_matcher(success, data, bindings)
156-
# if data is not a list, return nothing
157-
!islist(data) && return nothing
158-
# if data (is not a tree and is just a symbol) or (is a tree not starting with the default operation)
159-
if !iscall(car(data)) || (iscall(car(data)) && nameof(operation(car(data))) != defslot.operation)
160-
other_part_matcher = matchers[defslot_index==2 ? 2 : 3] # find the matcher of the normal part
161-
162-
# checks whether it matches the normal part
163-
# <-----------------(2)------------------------------->
164-
bindings = other_part_matcher((b,n) -> assoc(b, defslot.name, defslot.defaultValue), data, bindings)
165-
166-
if bindings === nothing
167-
return nothing
168-
end
169-
return success(bindings, 1)
170-
end
171-
172-
# (1)
173-
function loop(term, bindings′, matchers′) # Get it to compile faster
174-
if !islist(matchers′)
175-
if !islist(term)
176-
return success(bindings′, 1)
177-
end
178-
return nothing
179-
end
180-
car(matchers′)(term, bindings′) do b, n
181-
loop(drop_n(term, n), b, cdr(matchers′))
182-
end
183-
end
273+
normal_matcher = term_matcher_constructor(term, acSets)
184274

185-
loop(car(data), bindings, matchers) # Try to eat exactly one term
275+
function defslot_term_matcher(success, data, bindings)
276+
!islist(data) && return nothing # if data is not a list, return nothing
277+
# call the normal matcher, with success function that returns the bindings (foo1)
278+
# <-foo1->
279+
result = normal_matcher((b,n)->b, data, bindings)
280+
result !== nothing && return success(result, 1)
281+
# if no match, try to match with a defslot.
282+
# checks whether it matches the normal part if yes executes foo2
283+
# foo2: adds the pair (default value name, default value) to the found bindings
284+
# after checking predicate and presence in the bindings. If added
285+
# successfully returns the bindings (foo3), otherwise return nothing
286+
# <-------------------foo2----------------------------------->
287+
# <-foo3->
288+
result = other_part_matcher((b,n)->defslot_matcher((b,n)->b, (defslot.defaultValue,), b), data, bindings)
289+
result !== nothing && return success(result, 1)
290+
nothing
186291
end
187292
end

0 commit comments

Comments
 (0)