Skip to content

Commit 9e17a8f

Browse files
committed
removed smrule macro and added commutativity checks to the rule macro
1 parent 9f650a4 commit 9e17a8f

File tree

4 files changed

+35
-45
lines changed

4 files changed

+35
-45
lines changed

src/SymbolicUtils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ export Rewriters
5454
# A library for composing together expr -> expr functions
5555

5656
using Combinatorics: permutations, combinations
57-
export @rule, @acrule, @smrule, RuleSet
57+
export @rule, @acrule, RuleSet
5858

5959
# Rule type and @rule macro
6060
include("rule.jl")

src/matchers.jl

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
# 3. Callback: takes arguments Dictionary × Number of elements matched
77
#
88

9-
function matcher(val::Any; acSets = nothing)
9+
function matcher(val::Any, acSets)
1010
# if val is a call (like an operation) creates a term matcher or term matcher with defslot
1111
if iscall(val)
1212
# if has two arguments and one of them is a DefSlot, create a term matcher with defslot
@@ -25,7 +25,7 @@ function matcher(val::Any; acSets = nothing)
2525
end
2626

2727
# acSets is not used but needs to be there in case matcher(::Slot) is directly called from the macro
28-
function matcher(slot::Slot; acSets = nothing)
28+
function matcher(slot::Slot, acSets)
2929
function slot_matcher(next, data, bindings)
3030
!islist(data) && return nothing
3131
val = get(bindings, slot.name, nothing)
@@ -44,8 +44,8 @@ end
4444
# this is called only when defslot_term_matcher finds the operation and tries
4545
# to match it, so no default value used. So the same function as slot_matcher
4646
# can be used
47-
function matcher(defslot::DefSlot; acSets = nothing)
48-
matcher(Slot(defslot.name, defslot.predicate))
47+
function matcher(defslot::DefSlot, acSets)
48+
matcher(Slot(defslot.name, defslot.predicate), nothing) # slot matcher doesnt use acsets
4949
end
5050

5151
# returns n == offset, 0 if failed
@@ -76,7 +76,7 @@ function trymatchexpr(data, value, n)
7676
end
7777
end
7878

79-
function matcher(segment::Segment; acSets=nothing)
79+
function matcher(segment::Segment, acSets)
8080
function segment_matcher(success, data, bindings)
8181
val = get(bindings, segment.name, nothing)
8282

@@ -105,7 +105,7 @@ function matcher(segment::Segment; acSets=nothing)
105105
end
106106

107107
function term_matcher_constructor(term, acSets)
108-
matchers = (matcher(operation(term); acSets=acSets), map(x->matcher(x;acSets=acSets), arguments(term))...,)
108+
matchers = (matcher(operation(term), acSets), map(x->matcher(x,acSets), arguments(term))...,)
109109

110110
function loop(term, bindings′, matchers′) # Get it to compile faster
111111
if !islist(matchers′)
@@ -181,14 +181,21 @@ function term_matcher_constructor(term, acSets)
181181
operation(term) !== operation(car(data)) && return nothing # if the operation of data is not the correct one, don't even try
182182

183183
T = symtype(car(data))
184-
f = operation(car(data))
185-
data_args = arguments(car(data))
186-
187-
for inds in acSets(eachindex(data_args), length(arguments(term)))
188-
candidate = Term{T}(f, @views data_args[inds])
189-
190-
result = loop(candidate, bindings, matchers)
191-
result !== nothing && length(data_args) == length(inds) && return success(result,1)
184+
if T <: Number
185+
f = operation(car(data))
186+
data_args = arguments(car(data))
187+
188+
for inds in acSets(eachindex(data_args), length(arguments(term)))
189+
candidate = Term{T}(f, @views data_args[inds])
190+
191+
result = loop(candidate, bindings, matchers)
192+
result !== nothing && length(data_args) == length(inds) && return success(result,1)
193+
end
194+
# if car(data) does not subtype to number, it might not be commutative
195+
else
196+
# call the normal matcher
197+
result = loop(car(data), bindings, matchers)
198+
result !== nothing && return success(result, 1)
192199
end
193200
return nothing
194201
end
@@ -214,7 +221,7 @@ function defslot_term_matcher_constructor(term, acSets)
214221
defslot_index = findfirst(x -> isa(x, DefSlot), a) # find the defslot in the term
215222
defslot = a[defslot_index]
216223
if length(a) == 2
217-
other_part_matcher = matcher(a[defslot_index == 1 ? 2 : 1]; acSets = acSets)
224+
other_part_matcher = matcher(a[defslot_index == 1 ? 2 : 1], acSets)
218225
else
219226
others = [a[i] for i in eachindex(a) if i != defslot_index]
220227
T = symtype(term)

src/rule.jl

Lines changed: 4 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -387,24 +387,6 @@ _In the consequent pattern_: Use `(@ctx)` to access the context object on the ri
387387
of an expression.
388388
"""
389389
macro rule(expr)
390-
@assert expr.head == :call && expr.args[1] == :(=>)
391-
lhs = expr.args[2]
392-
rhs = rewrite_rhs(expr.args[3])
393-
keys = Symbol[]
394-
lhs_term = makepattern(lhs, keys)
395-
unique!(keys)
396-
quote
397-
$(__source__)
398-
lhs_pattern = $(lhs_term)
399-
Rule($(QuoteNode(expr)),
400-
lhs_pattern,
401-
matcher(lhs_pattern),
402-
__MATCHES__ -> $(makeconsequent(rhs)),
403-
rule_depth($lhs_term))
404-
end
405-
end
406-
407-
macro smrule(expr)
408390
@assert expr.head == :call && expr.args[1] == :(=>)
409391
lhs = expr.args[2]
410392
rhs = rewrite_rhs(expr.args[3])
@@ -417,7 +399,7 @@ macro smrule(expr)
417399
Rule(
418400
$(QuoteNode(expr)),
419401
lhs_pattern,
420-
matcher(lhs_pattern; acSets = permutations),
402+
matcher(lhs_pattern, permutations),
421403
__MATCHES__ -> $(makeconsequent(rhs)),
422404
rule_depth($lhs_term)
423405
)
@@ -455,7 +437,7 @@ macro capture(ex, lhs)
455437
lhs_pattern = $(lhs_term)
456438
__MATCHES__ = Rule($(QuoteNode(lhs)),
457439
lhs_pattern,
458-
matcher(lhs_pattern),
440+
matcher(lhs_pattern, nothing),
459441
identity,
460442
rule_depth($lhs_term))($(esc(ex)))
461443
if __MATCHES__ !== nothing
@@ -523,7 +505,7 @@ macro acrule(expr)
523505
lhs_pattern = $(lhs_term)
524506
rule = Rule($(QuoteNode(expr)),
525507
lhs_pattern,
526-
matcher(lhs_pattern; acSets = permutations),
508+
matcher(lhs_pattern, permutations),
527509
__MATCHES__ -> $(makeconsequent(rhs)),
528510
rule_depth($lhs_term))
529511
ACRule(permutations, rule, $arity)
@@ -545,7 +527,7 @@ macro ordered_acrule(expr)
545527
lhs_pattern = $(lhs_term)
546528
rule = Rule($(QuoteNode(expr)),
547529
lhs_pattern,
548-
matcher(lhs_pattern; acSets = combinations),
530+
matcher(lhs_pattern, combinations),
549531
__MATCHES__ -> $(makeconsequent(rhs)),
550532
rule_depth($lhs_term))
551533
ACRule(combinations, rule, $arity)

test/rewrite.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,19 +48,20 @@ end
4848
end
4949

5050
@testset "Commutative + and *" begin
51-
r1 = @acrule exp(sin(~x) + cos(~x)) => ~x
51+
r1 = @rule exp(sin(~x) + cos(~x)) => ~x
52+
# using a or x changes the order of the arguments in the call
5253
@test r1(exp(sin(a)+cos(a))) === a
5354
@test r1(exp(sin(x)+cos(x))) === x
54-
r2 = @acrule (~x+~y)*(~z+~w)^(~m) => (~x, ~y, ~z, ~w, ~m)
55-
r3 = @acrule (~z+~w)^(~m)*(~x+~y) => (~x, ~y, ~z, ~w, ~m)
55+
r2 = @rule (~x+~y)*(~z+~w)^(~m) => (~x, ~y, ~z, ~w, ~m)
56+
r3 = @rule (~z+~w)^(~m)*(~x+~y) => (~x, ~y, ~z, ~w, ~m)
5657
@test r2((a+b)*(x+c)^b) === (a, b, x, c, b)
5758
@test r3((a+b)*(x+c)^b) === (a, b, x, c, b)
58-
rPredicate1 = @acrule ~x::(x->isa(x,Number)) + ~y => (~x, ~y)
59-
rPredicate2 = @acrule ~y + ~x::(x->isa(x,Number)) => (~x, ~y)
59+
rPredicate1 = @rule ~x::(x->isa(x,Number)) + ~y => (~x, ~y)
60+
rPredicate2 = @rule ~y + ~x::(x->isa(x,Number)) => (~x, ~y)
6061
@test rPredicate1(2+x) === (2, x)
6162
@test rPredicate2(2+x) === (2, x)
62-
r5 = @acrule (~y*(~z+~w))+~x => (~x, ~y, ~z, ~w)
63-
r6 = @acrule ~x+((~z+~w)*~y) => (~x, ~y, ~z, ~w)
63+
r5 = @rule (~y*(~z+~w))+~x => (~x, ~y, ~z, ~w)
64+
r6 = @rule ~x+((~z+~w)*~y) => (~x, ~y, ~z, ~w)
6465
@test r5(c*(a+b)+d) === (d, c, a, b)
6566
@test r6(c*(a+b)+d) === (d, c, a, b)
6667
end

0 commit comments

Comments
 (0)