From a07bdbc88f80edda5072ec489f7ad1bbe7112faa Mon Sep 17 00:00:00 2001 From: Bumblebee00 Date: Sat, 14 Jun 2025 18:52:35 +0200 Subject: [PATCH 01/26] first version, really caothic, and doesn't work with defslot powers --- src/matchers.jl | 63 +++++++++++++++++++++++++++++++++++++++++++++++++ test/rewrite.jl | 17 ++++++++++--- 2 files changed, 77 insertions(+), 3 deletions(-) diff --git a/src/matchers.jl b/src/matchers.jl index e08462747..d0c98cc5a 100644 --- a/src/matchers.jl +++ b/src/matchers.jl @@ -10,8 +10,12 @@ function matcher(val::Any) # if val is a call (like an operation) creates a term matcher or term matcher with defslot if iscall(val) # if has two arguments and one of them is a DefSlot, create a term matcher with defslot + # just two arguments bc defslot is only supported with operations with two args: *, ^, + if length(arguments(val)) == 2 && any(x -> isa(x, DefSlot), arguments(val)) return defslot_term_matcher_constructor(val) + # else (a)^(b) can also match 1/( (a)^(b) ) , just with b of oppsite sign + elseif operation(val) == ^ + return neg_pow_term_matcher_constructor(val) # else return a normal term matcher else return term_matcher_constructor(val) @@ -40,6 +44,20 @@ function matcher(slot::Slot) end end +function opposite_sign_matcher(slot::Slot) + function slot_matcher(next, data, bindings) + !islist(data) && return nothing + val = get(bindings, slot.name, nothing) + if val !== nothing + if isequal(val, car(data)) + return next(bindings, 1) + end + elseif slot.predicate(car(data)) + next(assoc(bindings, slot.name, -car(data)), 1) # this - is the only differenct wrt matcher(slot::Slot) + end + end +end + # this is called only when defslot_term_matcher finds the operation and tries # to match it, so no default value used. So the same function as slot_matcher # can be used @@ -133,6 +151,51 @@ function term_matcher_constructor(term) end end +# (a)^(b) can also match 1/( (a)^(b) ) , just with b of oppsite sign +function neg_pow_term_matcher_constructor(term) + matchers = (matcher(operation(term)), map(matcher, arguments(term))...,) + + function neg_pow_term_matcher(success, data, bindings) + !islist(data) && return nothing # if data is not a list, return nothing + !iscall(car(data)) && return nothing # if first element is not a call, return nothing + + function loop(term, bindings′, matchers′) + if !islist(matchers′) + if !islist(term) + return success(bindings′, 1) + end + return nothing + end + car(matchers′)(term, bindings′) do b, n + loop(drop_n(term, n), b, cdr(matchers′)) + end + end + + result = loop(car(data), bindings, matchers) + # if data is of the form 1/(...)^(...), it might match with negative exponent + if result === nothing && (operation(car(data))==/) && arguments(car(data))[1]==1 && iscall(arguments(car(data))[2]) && (operation(arguments(car(data))[2])==^) + denominator = arguments(car(data))[2] + # let's say data = a^b with a and b can be whatever + # if b is not a number then call the loop function with a^-b + if !isa(arguments(denominator)[2], Number) + frankestein = arguments(denominator)[1] ^ -(arguments(denominator)[2]) + result = loop(frankestein, bindings, matchers) + else + # if b is a number, like 3, we cant call loop with a^-3 bc it + # will automatically transform into 1/a^3. Therfore we need to + # create a matcher that flips the sign of the exponent. I created + # this matecher just for `Slot`s and not for terms, because if b + # is a number and not a call, certainly doesn't match a term (I hope). + if isa(arguments(term)[2], Slot) + matchers2 = (matcher(operation(term)), matcher(arguments(term)[1]), opposite_sign_matcher(arguments(term)[2])) # is this ok to be here or should it be outside neg_pow_term_matcher? + result = loop(denominator, bindings, matchers2) + end + end + end + result + end +end + # creates a matcher for a term containing a defslot, such as: # (~x + ...complicated pattern...) * ~!y # normal part (can bee a tree) operation defslot part diff --git a/test/rewrite.jl b/test/rewrite.jl index b8996f79e..e0eadd2fa 100644 --- a/test/rewrite.jl +++ b/test/rewrite.jl @@ -77,9 +77,20 @@ end @test r_pow2(a+b) === 1 r_mix = @rule (~x + (~y)*(~!c))^(~!m) => ~m + ~c - @test r_mix((a + b*c)^2) === 2 + c - @test r_mix((a + b*c)) === 1 + c - @test r_mix((a + b)) === 2 #1+1 + @test r_mix((a + b*c)^2) === (2, c) + @test r_mix((a + b*c)) === (1, c) + @test r_mix((a + b)) === (1, 1) +end + +@testset "1/power matches power with exponent of opposite sign" begin + r1 = @rule (~x)^(~y) => (~x, ~y) # rule with slot as exponent + @test r1(1/a^b) === (a, -b) # uses frankestein + @test r1(1/a^(b+2c)) === (a, -b-2c) # uses frankestein + @test r1(1/a^2) === (a, -2) # uses opposite_sign_matcher + + r2 = @rule (~x)^(~y + ~z) => (~x, ~y, ~z) # rule with term as exponent + @test r2(1/a^(b+2c)) === (a, -b, -2c) # uses frankestein + @test r2(1/a^3) === nothing # should use a term_matcher that flips the sign, but is not implemented end using SymbolicUtils: @capture From 12843da575a1f4efe875e60a0636c1fa397eca0b Mon Sep 17 00:00:00 2001 From: Bumblebee00 Date: Sat, 14 Jun 2025 19:34:29 +0200 Subject: [PATCH 02/26] second version, really caothic, but works with defslotpowers --- src/matchers.jl | 135 ++++++++++++++++++++++++++---------------------- test/rewrite.jl | 6 +++ 2 files changed, 79 insertions(+), 62 deletions(-) diff --git a/src/matchers.jl b/src/matchers.jl index d0c98cc5a..5d8aaa4b3 100644 --- a/src/matchers.jl +++ b/src/matchers.jl @@ -13,9 +13,6 @@ function matcher(val::Any) # just two arguments bc defslot is only supported with operations with two args: *, ^, + if length(arguments(val)) == 2 && any(x -> isa(x, DefSlot), arguments(val)) return defslot_term_matcher_constructor(val) - # else (a)^(b) can also match 1/( (a)^(b) ) , just with b of oppsite sign - elseif operation(val) == ^ - return neg_pow_term_matcher_constructor(val) # else return a normal term matcher else return term_matcher_constructor(val) @@ -58,6 +55,10 @@ function opposite_sign_matcher(slot::Slot) end end +function opposite_sign_matcher(defslot::DefSlot) + opposite_sign_matcher(Slot(defslot.name, defslot.predicate)) +end + # this is called only when defslot_term_matcher finds the operation and tries # to match it, so no default value used. So the same function as slot_matcher # can be used @@ -147,48 +148,27 @@ function term_matcher_constructor(term) # the length of the list, is considered empty end - loop(car(data), bindings, matchers) # Try to eat exactly one term - end -end - -# (a)^(b) can also match 1/( (a)^(b) ) , just with b of oppsite sign -function neg_pow_term_matcher_constructor(term) - matchers = (matcher(operation(term)), map(matcher, arguments(term))...,) - - function neg_pow_term_matcher(success, data, bindings) - !islist(data) && return nothing # if data is not a list, return nothing - !iscall(car(data)) && return nothing # if first element is not a call, return nothing - - function loop(term, bindings′, matchers′) - if !islist(matchers′) - if !islist(term) - return success(bindings′, 1) - end - return nothing - end - car(matchers′)(term, bindings′) do b, n - loop(drop_n(term, n), b, cdr(matchers′)) - end - end - result = loop(car(data), bindings, matchers) - # if data is of the form 1/(...)^(...), it might match with negative exponent - if result === nothing && (operation(car(data))==/) && arguments(car(data))[1]==1 && iscall(arguments(car(data))[2]) && (operation(arguments(car(data))[2])==^) - denominator = arguments(car(data))[2] - # let's say data = a^b with a and b can be whatever - # if b is not a number then call the loop function with a^-b - if !isa(arguments(denominator)[2], Number) - frankestein = arguments(denominator)[1] ^ -(arguments(denominator)[2]) - result = loop(frankestein, bindings, matchers) - else - # if b is a number, like 3, we cant call loop with a^-3 bc it - # will automatically transform into 1/a^3. Therfore we need to - # create a matcher that flips the sign of the exponent. I created - # this matecher just for `Slot`s and not for terms, because if b - # is a number and not a call, certainly doesn't match a term (I hope). - if isa(arguments(term)[2], Slot) - matchers2 = (matcher(operation(term)), matcher(arguments(term)[1]), opposite_sign_matcher(arguments(term)[2])) # is this ok to be here or should it be outside neg_pow_term_matcher? - result = loop(denominator, bindings, matchers2) + # if data is of the alternative form 1/(...)^(...), it might match with negative exponent + if operation(term)==^ + alternative_form = (operation(car(data))==/) && arguments(car(data))[1]==1 && iscall(arguments(car(data))[2]) && (operation(arguments(car(data))[2])==^) + if result === nothing && alternative_form + denominator = arguments(car(data))[2] + # let's say data = a^b with a and b can be whatever + # if b is not a number then call the loop function with a^-b + if !isa(arguments(denominator)[2], Number) + frankestein = arguments(denominator)[1] ^ -(arguments(denominator)[2]) + result = loop(frankestein, bindings, matchers) + else + # if b is a number, like 3, we cant call loop with a^-3 bc it + # will automatically transform into 1/a^3. Therfore we need to + # create a matcher that flips the sign of the exponent. I created + # this matecher just for `Slot`s and not for terms, because if b + # is a number and not a call, certainly doesn't match a term (I hope). + if isa(arguments(term)[2], Slot) + matchers2 = (matcher(operation(term)), matcher(arguments(term)[1]), opposite_sign_matcher(arguments(term)[2])) # is this ok to be here or should it be outside neg_pow_term_matcher? + result = loop(denominator, bindings, matchers2) + end end end end @@ -205,7 +185,7 @@ end # if yes (1): continues like term_matcher # if no checks whether data matches the normal part # if no returns nothing, rule is not applied -# if yes (2): adds the pair (default value name, default value) to the found bindings and +# if yes (3): adds the pair (default value name, default value) to the found bindings and # calls the success function like term_matcher would do function defslot_term_matcher_constructor(term) @@ -218,8 +198,53 @@ function defslot_term_matcher_constructor(term) function defslot_term_matcher(success, data, bindings) # if data is not a list, return nothing !islist(data) && return nothing + result = nothing + if iscall(car(data)) + # (1) + function loop(term, bindings′, matchers′) # Get it to compile faster + if !islist(matchers′) + if !islist(term) + return success(bindings′, 1) + end + return nothing + end + car(matchers′)(term, bindings′) do b, n + loop(drop_n(term, n), b, cdr(matchers′)) + end + end + + result = loop(car(data), bindings, matchers) # Try to eat exactly one term + # if data is of the alternative form 1/(...)^(...), it might match with negative exponent + if operation(term)==^ + alternative_form = (operation(car(data))==/) && arguments(car(data))[1]==1 && iscall(arguments(car(data))[2]) && (operation(arguments(car(data))[2])==^) + if result === nothing && alternative_form + denominator = arguments(car(data))[2] + # let's say data = a^b with a and b can be whatever + # if b is not a number then call the loop function with a^-b + if !isa(arguments(denominator)[2], Number) + frankestein = arguments(denominator)[1] ^ -(arguments(denominator)[2]) + result = loop(frankestein, bindings, matchers) + else + # if b is a number, like 3, we cant call loop with a^-3 bc it + # will automatically transform into 1/a^3. Therfore we need to + # create a matcher that flips the sign of the exponent. I created + # this matecher just for `DefSlot`s and not for terms, because if b + # is a number and not a call, certainly doesn't match a term (I hope). + if isa(arguments(term)[2], DefSlot) + matchers2 = (matcher(operation(term)), matcher(arguments(term)[1]), opposite_sign_matcher(arguments(term)[2])) # is this ok to be here or should it be outside neg_pow_term_matcher? + result = loop(denominator, bindings, matchers2) + end + end + end + end + # (2) + if result !== nothing + return result + end + end + # if data (is not a tree and is just a symbol) or (is a tree not starting with the default operation) - if !iscall(car(data)) || (iscall(car(data)) && nameof(operation(car(data))) != defslot.operation) + if ( !iscall(car(data)) || (iscall(car(data)) && nameof(operation(car(data))) != defslot.operation) ) other_part_matcher = matchers[defslot_index==2 ? 2 : 3] # find the matcher of the normal part # checks whether it matches the normal part @@ -229,22 +254,8 @@ function defslot_term_matcher_constructor(term) if bindings === nothing return nothing end - return success(bindings, 1) + result = success(bindings, 1) end - - # (1) - function loop(term, bindings′, matchers′) # Get it to compile faster - if !islist(matchers′) - if !islist(term) - return success(bindings′, 1) - end - return nothing - end - car(matchers′)(term, bindings′) do b, n - loop(drop_n(term, n), b, cdr(matchers′)) - end - end - - loop(car(data), bindings, matchers) # Try to eat exactly one term + result end end diff --git a/test/rewrite.jl b/test/rewrite.jl index e0eadd2fa..3cd545f2c 100644 --- a/test/rewrite.jl +++ b/test/rewrite.jl @@ -91,6 +91,12 @@ end r2 = @rule (~x)^(~y + ~z) => (~x, ~y, ~z) # rule with term as exponent @test r2(1/a^(b+2c)) === (a, -b, -2c) # uses frankestein @test r2(1/a^3) === nothing # should use a term_matcher that flips the sign, but is not implemented + + r1defslot = @rule (~x)^(~!y) => (~x, ~y) # rule with slot as exponent + @test r1defslot(1/a^b) === (a, -b) # uses frankestein + @test r1defslot(1/a^(b+2c)) === (a, -b-2c) # uses frankestein + @test r1defslot(1/a^2) === (a, -2) # uses opposite_sign_matcher + @test r1defslot(a) === (a, 1) end using SymbolicUtils: @capture From d81145ef33580a41cf2423385d8ace6e76877a75 Mon Sep 17 00:00:00 2001 From: Bumblebee00 Date: Wed, 18 Jun 2025 12:55:11 +0200 Subject: [PATCH 03/26] fix typo --- test/rewrite.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/rewrite.jl b/test/rewrite.jl index 3cd545f2c..194b9fe92 100644 --- a/test/rewrite.jl +++ b/test/rewrite.jl @@ -76,7 +76,7 @@ end @test r_pow2((a+b)^c) === c @test r_pow2(a+b) === 1 - r_mix = @rule (~x + (~y)*(~!c))^(~!m) => ~m + ~c + r_mix = @rule (~x + (~y)*(~!c))^(~!m) => (~m, ~c) @test r_mix((a + b*c)^2) === (2, c) @test r_mix((a + b*c)) === (1, c) @test r_mix((a + b)) === (1, 1) @@ -92,7 +92,7 @@ end @test r2(1/a^(b+2c)) === (a, -b, -2c) # uses frankestein @test r2(1/a^3) === nothing # should use a term_matcher that flips the sign, but is not implemented - r1defslot = @rule (~x)^(~!y) => (~x, ~y) # rule with slot as exponent + r1defslot = @rule (~x)^(~!y) => (~x, ~y) # rule with defslot as exponent @test r1defslot(1/a^b) === (a, -b) # uses frankestein @test r1defslot(1/a^(b+2c)) === (a, -b-2c) # uses frankestein @test r1defslot(1/a^2) === (a, -2) # uses opposite_sign_matcher From 79118fc15877b6183ee4d5122fa9872fbe4d489b Mon Sep 17 00:00:00 2001 From: Bumblebee00 Date: Wed, 18 Jun 2025 12:55:42 +0200 Subject: [PATCH 04/26] operation + and * are always commutative now --- src/matchers.jl | 216 ++++++++++++++++++++++-------------------------- 1 file changed, 101 insertions(+), 115 deletions(-) diff --git a/src/matchers.jl b/src/matchers.jl index 5d8aaa4b3..3fe3d0b14 100644 --- a/src/matchers.jl +++ b/src/matchers.jl @@ -41,6 +41,16 @@ function matcher(slot::Slot) end end +# this is called only when defslot_term_matcher finds the operation and tries +# to match it, so no default value used. So the same function as slot_matcher +# can be used +function matcher(defslot::DefSlot) + matcher(Slot(defslot.name, defslot.predicate)) +end + +# function opposite_sign_matcher(val::Any) +# end + function opposite_sign_matcher(slot::Slot) function slot_matcher(next, data, bindings) !islist(data) && return nothing @@ -59,13 +69,6 @@ function opposite_sign_matcher(defslot::DefSlot) opposite_sign_matcher(Slot(defslot.name, defslot.predicate)) end -# this is called only when defslot_term_matcher finds the operation and tries -# to match it, so no default value used. So the same function as slot_matcher -# can be used -function matcher(defslot::DefSlot) - matcher(Slot(defslot.name, defslot.predicate)) -end - # returns n == offset, 0 if failed function trymatchexpr(data, value, n) if !islist(value) @@ -124,138 +127,121 @@ end function term_matcher_constructor(term) matchers = (matcher(operation(term)), map(matcher, arguments(term))...,) - - function term_matcher(success, data, bindings) - !islist(data) && return nothing # if data is not a list, return nothing - !iscall(car(data)) && return nothing # if first element is not a call, return nothing - - function loop(term, bindings′, matchers′) # Get it to compile faster - if !islist(matchers′) - if !islist(term) - return success(bindings′, 1) - end - return nothing - end - car(matchers′)(term, bindings′) do b, n - loop(drop_n(term, n), b, cdr(matchers′)) + + function loop(term, bindings′, matchers′) # Get it to compile faster + if !islist(matchers′) + if !islist(term) + return bindings′ end - # explanation of above 3 lines: - # car(matchers′)(b,n -> loop(drop_n(term, n), b, cdr(matchers′)), term, bindings′) - # <------ next(b,n) ----------------------------> - # car = first element of list, cdr = rest of the list, drop_n = drop first n elements of list - # Calls the first matcher, with the "next" function being loop again but with n terms dropepd from term - # Term is a linked list (a list and a index). drop n advances the index. when the index sorpasses - # the length of the list, is considered empty + return nothing end + car(matchers′)(term, bindings′) do b, n + loop(drop_n(term, n), b, cdr(matchers′)) + end + # explenation of above 3 lines: + # car(matchers′)(b,n -> loop(drop_n(term, n), b, cdr(matchers′)), term, bindings′) + # <------ next(b,n) ----------------------------> + # car = first element of list, cdr = rest of the list, drop_n = drop first n elements of list + # Calls the first matcher, with the "next" function being loop again but with n terms dropepd from term + # Term is a linked list (a list and a index). drop n advances the index. when the index sorpasses + # the length of the list, is considered empty + end - result = loop(car(data), bindings, matchers) - # if data is of the alternative form 1/(...)^(...), it might match with negative exponent - if operation(term)==^ - alternative_form = (operation(car(data))==/) && arguments(car(data))[1]==1 && iscall(arguments(car(data))[2]) && (operation(arguments(car(data))[2])==^) - if result === nothing && alternative_form + # if the operation is a pow, we have to match also 1/(...)^(...) with negative exponent + if operation(term)==^ + # the below 4 lines could stay in the function term_matcher_pow, but + # are here to speed up the rule matcher function + cond = isa(arguments(term)[2], Slot) || isa(arguments(term)[2], DefSlot) + if cond + matchers_modified = (matcher(operation(term)), matcher(arguments(term)[1]), opposite_sign_matcher(arguments(term)[2])) # is this ok to be here or should it be outside neg_pow_term_matcher? + end + + function term_matcher_pow(success, data, bindings) + !islist(data) && return nothing # if data is not a list, return nothing + !iscall(car(data)) && return nothing # if first element is not a call, return nothing + + result = loop(car(data), bindings, matchers) + result !== nothing && return success(result, 1) + + # if data is of the alternative form 1/(...)^(...), it might match with negative exponent + if (operation(car(data))==/) && arguments(car(data))[1]==1 && iscall(arguments(car(data))[2]) && (operation(arguments(car(data))[2])==^) denominator = arguments(car(data))[2] # let's say data = a^b with a and b can be whatever # if b is not a number then call the loop function with a^-b if !isa(arguments(denominator)[2], Number) frankestein = arguments(denominator)[1] ^ -(arguments(denominator)[2]) result = loop(frankestein, bindings, matchers) - else - # if b is a number, like 3, we cant call loop with a^-3 bc it - # will automatically transform into 1/a^3. Therfore we need to - # create a matcher that flips the sign of the exponent. I created - # this matecher just for `Slot`s and not for terms, because if b - # is a number and not a call, certainly doesn't match a term (I hope). - if isa(arguments(term)[2], Slot) - matchers2 = (matcher(operation(term)), matcher(arguments(term)[1]), opposite_sign_matcher(arguments(term)[2])) # is this ok to be here or should it be outside neg_pow_term_matcher? - result = loop(denominator, bindings, matchers2) - end + # if b is a number, like 3, we cant call loop with a^-3 bc it + # will automatically transform into 1/a^3. Therfore we need to + # create a matcher that flips the sign of the exponent. I created + # this matecher just for `Slot`s and `DefSlot`s, but not for + # terms or literals, because if b is a number and not a call, + # certainly doesn't match a term (I hope). + # Also not a literal because...? + elseif cond + result = loop(denominator, bindings, matchers_modified) end end + if result !== nothing + return success(result, 1) + end + end + return term_matcher_pow + # if the operation is commutative + elseif operation(term) in [+, *] + all_matchers = [] + args = arguments(term) + for inds in permutations(eachindex(args), length(args)) + reord = @views args[inds] + push!(all_matchers, (matcher(operation(term)), map(matcher, reord)...,)) + end + + function term_matcher_comm(success, data, bindings) + !islist(data) && return nothing # if data is not a list, return nothing + !iscall(car(data)) && return nothing # if first element is not a call, return nothing + + for m in all_matchers + result = loop(car(data), bindings, m) + result !== nothing && return success(result, 1) + end + end + return term_matcher_comm + else + function term_matcher(success, data, bindings) + !islist(data) && return nothing # if data is not a list, return nothing + !iscall(car(data)) && return nothing # if first element is not a call, return nothing + + result = loop(car(data), bindings, matchers) + if result !== nothing + return success(result, 1) + end end - result + return term_matcher end end # creates a matcher for a term containing a defslot, such as: # (~x + ...complicated pattern...) * ~!y # normal part (can bee a tree) operation defslot part - -# defslot_term_matcher works like this: -# checks whether data starts with the default operation. -# if yes (1): continues like term_matcher -# if no checks whether data matches the normal part -# if no returns nothing, rule is not applied -# if yes (3): adds the pair (default value name, default value) to the found bindings and -# calls the success function like term_matcher would do - function defslot_term_matcher_constructor(term) - a = arguments(term) # length two bc defslot term matcher is allowed only with +,* and ^, that accept two arguments - matchers = (matcher(operation(term)), map(matcher, a)...) # create matchers for the operation and the two arguments of the term - + a = arguments(term) # lenght two bc defslot term matcher is allowed only with +,* and ^ that accept two arguments defslot_index = findfirst(x -> isa(x, DefSlot), a) # find the defslot in the term defslot = a[defslot_index] + other_part_matcher = matcher(defslot_index==1 ? a[2] : a[1]) # find the matcher of the normal part - function defslot_term_matcher(success, data, bindings) - # if data is not a list, return nothing - !islist(data) && return nothing - result = nothing - if iscall(car(data)) - # (1) - function loop(term, bindings′, matchers′) # Get it to compile faster - if !islist(matchers′) - if !islist(term) - return success(bindings′, 1) - end - return nothing - end - car(matchers′)(term, bindings′) do b, n - loop(drop_n(term, n), b, cdr(matchers′)) - end - end + normal_matcher = term_matcher_constructor(term) - result = loop(car(data), bindings, matchers) # Try to eat exactly one term - # if data is of the alternative form 1/(...)^(...), it might match with negative exponent - if operation(term)==^ - alternative_form = (operation(car(data))==/) && arguments(car(data))[1]==1 && iscall(arguments(car(data))[2]) && (operation(arguments(car(data))[2])==^) - if result === nothing && alternative_form - denominator = arguments(car(data))[2] - # let's say data = a^b with a and b can be whatever - # if b is not a number then call the loop function with a^-b - if !isa(arguments(denominator)[2], Number) - frankestein = arguments(denominator)[1] ^ -(arguments(denominator)[2]) - result = loop(frankestein, bindings, matchers) - else - # if b is a number, like 3, we cant call loop with a^-3 bc it - # will automatically transform into 1/a^3. Therfore we need to - # create a matcher that flips the sign of the exponent. I created - # this matecher just for `DefSlot`s and not for terms, because if b - # is a number and not a call, certainly doesn't match a term (I hope). - if isa(arguments(term)[2], DefSlot) - matchers2 = (matcher(operation(term)), matcher(arguments(term)[1]), opposite_sign_matcher(arguments(term)[2])) # is this ok to be here or should it be outside neg_pow_term_matcher? - result = loop(denominator, bindings, matchers2) - end - end - end - end - # (2) - if result !== nothing - return result - end - end - + function defslot_term_matcher(success, data, bindings) + result = normal_matcher(success, data, bindings) + result !== nothing && return result + # if no match, try to match with a defslot # if data (is not a tree and is just a symbol) or (is a tree not starting with the default operation) - if ( !iscall(car(data)) || (iscall(car(data)) && nameof(operation(car(data))) != defslot.operation) ) - other_part_matcher = matchers[defslot_index==2 ? 2 : 3] # find the matcher of the normal part - - # checks whether it matches the normal part - # <-----------------(2)-------------------------------> - bindings = other_part_matcher((b,n) -> assoc(b, defslot.name, defslot.defaultValue), data, bindings) - - if bindings === nothing - return nothing - end - result = success(bindings, 1) + if ( !iscall(car(data)) || (iscall(car(data)) && nameof(operation(car(data))) != defslot.operation) ) + # checks wether it matches the normal part if yes executes (foo) + # (foo): adds the pair (default value name, default value) to the found bindings + # <------------------(foo)----------------------------> + result = other_part_matcher((b,n) -> assoc(b, defslot.name, defslot.defaultValue), data, bindings) + result !== nothing && return success(result, 1) end - result end end From cd0cc33c2ab1605e8869f245901b10da779d3d53 Mon Sep 17 00:00:00 2001 From: Bumblebee00 Date: Wed, 18 Jun 2025 15:32:53 +0200 Subject: [PATCH 05/26] added some tests of commutative operations --- test/rewrite.jl | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/test/rewrite.jl b/test/rewrite.jl index 194b9fe92..c8ed107cd 100644 --- a/test/rewrite.jl +++ b/test/rewrite.jl @@ -2,7 +2,7 @@ using SymbolicUtils include("utils.jl") -@syms a b c +@syms a b c x @testset "Equality" begin @eqtest a == a @@ -47,6 +47,20 @@ end @eqtest @rule(+(~~x,~y,~~x) => (~~x, ~y, ~~x))(term(+,6,type=Any)) == ([], 6, []) end +@testset "Commutative + and *" begin + r1 = @rule sin(~x) + cos(~x) => ~x + @test r1(sin(a)+cos(a)) === a + @test r1(sin(x)+cos(x)) === x + r2 = @rule (~x+~y)*(~z+~w)^(~m) => (~x, ~y, ~z, ~w, ~m) + r3 = @rule (~z+~w)^(~m)*(~x+~y) => (~x, ~y, ~z, ~w, ~m) + @test r2((a+b)*(x+c)^b) === (a, b, x, c, b) + @test r3((a+b)*(x+c)^b) === (a, b, x, c, b) + rPredicate1 = @rule ~x::(x->isa(x,Number)) + ~y => (~x, ~y) + rPredicate2 = @rule ~y + ~x::(x->isa(x,Number)) => (~x, ~y) + @test rPredicate1(2+x) === (2, x) + @test rPredicate2(2+x) === (2, x) +end + @testset "Slot matcher with default value" begin r_sum = @rule (~x + ~!y)^2 => ~y @test r_sum((a + b)^2) === b From bd06d79b6da661f1118b1ca6cabd3c4024090e80 Mon Sep 17 00:00:00 2001 From: Bumblebee00 Date: Thu, 19 Jun 2025 10:58:12 +0200 Subject: [PATCH 06/26] fixed bug on defslot functionality --- src/matchers.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/matchers.jl b/src/matchers.jl index 3fe3d0b14..239f8ca1e 100644 --- a/src/matchers.jl +++ b/src/matchers.jl @@ -13,10 +13,9 @@ function matcher(val::Any) # just two arguments bc defslot is only supported with operations with two args: *, ^, + if length(arguments(val)) == 2 && any(x -> isa(x, DefSlot), arguments(val)) return defslot_term_matcher_constructor(val) - # else return a normal term matcher - else - return term_matcher_constructor(val) end + # else return a normal term matcher + return term_matcher_constructor(val) end function literal_matcher(next, data, bindings) @@ -232,6 +231,7 @@ function defslot_term_matcher_constructor(term) normal_matcher = term_matcher_constructor(term) function defslot_term_matcher(success, data, bindings) + !islist(data) && return nothing # if data is not a list, return nothing result = normal_matcher(success, data, bindings) result !== nothing && return result # if no match, try to match with a defslot From a1da82db9a8b3a2e0114bcb31041a96eb18e0ba9 Mon Sep 17 00:00:00 2001 From: Bumblebee00 Date: Thu, 19 Jun 2025 15:37:20 +0200 Subject: [PATCH 07/26] added defslot on operations with multiple arguments --- src/matchers.jl | 36 +++++++++++++++++++----------------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/src/matchers.jl b/src/matchers.jl index 239f8ca1e..8f3b3e969 100644 --- a/src/matchers.jl +++ b/src/matchers.jl @@ -11,7 +11,7 @@ function matcher(val::Any) if iscall(val) # if has two arguments and one of them is a DefSlot, create a term matcher with defslot # just two arguments bc defslot is only supported with operations with two args: *, ^, + - if length(arguments(val)) == 2 && any(x -> isa(x, DefSlot), arguments(val)) + if any(x -> isa(x, DefSlot), arguments(val)) return defslot_term_matcher_constructor(val) end # else return a normal term matcher @@ -188,13 +188,6 @@ function term_matcher_constructor(term) return term_matcher_pow # if the operation is commutative elseif operation(term) in [+, *] - all_matchers = [] - args = arguments(term) - for inds in permutations(eachindex(args), length(args)) - reord = @views args[inds] - push!(all_matchers, (matcher(operation(term)), map(matcher, reord)...,)) - end - function term_matcher_comm(success, data, bindings) !islist(data) && return nothing # if data is not a list, return nothing !iscall(car(data)) && return nothing # if first element is not a call, return nothing @@ -223,10 +216,17 @@ end # (~x + ...complicated pattern...) * ~!y # normal part (can bee a tree) operation defslot part function defslot_term_matcher_constructor(term) - a = arguments(term) # lenght two bc defslot term matcher is allowed only with +,* and ^ that accept two arguments + a = arguments(term) defslot_index = findfirst(x -> isa(x, DefSlot), a) # find the defslot in the term defslot = a[defslot_index] - other_part_matcher = matcher(defslot_index==1 ? a[2] : a[1]) # find the matcher of the normal part + if length(a) == 2 + other_part_matcher = matcher(a[defslot_index == 1 ? 2 : 1]) + else + others = [a[i] for i in eachindex(a) if i != defslot_index] + T = symtype(term) + f = operation(term) + other_part_matcher = term_matcher_constructor(Term{T}(f, others)) + end normal_matcher = term_matcher_constructor(term) @@ -236,12 +236,14 @@ function defslot_term_matcher_constructor(term) result !== nothing && return result # if no match, try to match with a defslot # if data (is not a tree and is just a symbol) or (is a tree not starting with the default operation) - if ( !iscall(car(data)) || (iscall(car(data)) && nameof(operation(car(data))) != defslot.operation) ) - # checks wether it matches the normal part if yes executes (foo) - # (foo): adds the pair (default value name, default value) to the found bindings - # <------------------(foo)----------------------------> - result = other_part_matcher((b,n) -> assoc(b, defslot.name, defslot.defaultValue), data, bindings) - result !== nothing && return success(result, 1) - end + + # checks wether it matches the normal part if yes executes (foo) + # (foo): adds the pair (default value name, default value) to the found bindings + # <------------------(foo)----------------------------> + result = other_part_matcher((b,n) -> assoc(b, defslot.name, defslot.defaultValue), data, bindings) + println(result) + result !== nothing && return success(result, 1) + + nothing end end From 7849e7a5b845e379c3f61a88d6aca0ff1660c28b Mon Sep 17 00:00:00 2001 From: Bumblebee00 Date: Thu, 19 Jun 2025 17:51:25 +0200 Subject: [PATCH 08/26] moved the commutativiry checks to only acrule macro --- src/matchers.jl | 64 +++++++++++++++++++++++++------------------------ src/rule.jl | 48 ++++++++++++++++++++++++++++--------- test/rewrite.jl | 18 ++++++++------ 3 files changed, 81 insertions(+), 49 deletions(-) diff --git a/src/matchers.jl b/src/matchers.jl index 8f3b3e969..bb22ea273 100644 --- a/src/matchers.jl +++ b/src/matchers.jl @@ -6,16 +6,16 @@ # 3. Callback: takes arguments Dictionary × Number of elements matched # -function matcher(val::Any) +function matcher(val::Any; acSets = nothing) # if val is a call (like an operation) creates a term matcher or term matcher with defslot if iscall(val) # if has two arguments and one of them is a DefSlot, create a term matcher with defslot # just two arguments bc defslot is only supported with operations with two args: *, ^, + if any(x -> isa(x, DefSlot), arguments(val)) - return defslot_term_matcher_constructor(val) + return defslot_term_matcher_constructor(val, acSets) end # else return a normal term matcher - return term_matcher_constructor(val) + return term_matcher_constructor(val, acSets) end function literal_matcher(next, data, bindings) @@ -24,7 +24,8 @@ function matcher(val::Any) end end -function matcher(slot::Slot) +# acSets is not used but needs to be there in case matcher(::Slot) is directly called from the macro +function matcher(slot::Slot; acSets = nothing) function slot_matcher(next, data, bindings) !islist(data) && return nothing val = get(bindings, slot.name, nothing) @@ -43,7 +44,7 @@ end # this is called only when defslot_term_matcher finds the operation and tries # to match it, so no default value used. So the same function as slot_matcher # can be used -function matcher(defslot::DefSlot) +function matcher(defslot::DefSlot; acSets = nothing) matcher(Slot(defslot.name, defslot.predicate)) end @@ -59,7 +60,7 @@ function opposite_sign_matcher(slot::Slot) return next(bindings, 1) end elseif slot.predicate(car(data)) - next(assoc(bindings, slot.name, -car(data)), 1) # this - is the only differenct wrt matcher(slot::Slot) + next(assoc(bindings, slot.name, -car(data)), 1) # this - is the only difference wrt matcher(slot::Slot) end end end @@ -96,7 +97,7 @@ function trymatchexpr(data, value, n) end end -function matcher(segment::Segment) +function matcher(segment::Segment; acSets=nothing) function segment_matcher(success, data, bindings) val = get(bindings, segment.name, nothing) @@ -124,8 +125,8 @@ function matcher(segment::Segment) end end -function term_matcher_constructor(term) - matchers = (matcher(operation(term)), map(matcher, arguments(term))...,) +function term_matcher_constructor(term, acSets) + matchers = (matcher(operation(term); acSets=acSets), map(x->matcher(x;acSets=acSets), arguments(term))...,) function loop(term, bindings′, matchers′) # Get it to compile faster if !islist(matchers′) @@ -137,7 +138,7 @@ function term_matcher_constructor(term) car(matchers′)(term, bindings′) do b, n loop(drop_n(term, n), b, cdr(matchers′)) end - # explenation of above 3 lines: + # explanation of above 3 lines: # car(matchers′)(b,n -> loop(drop_n(term, n), b, cdr(matchers′)), term, bindings′) # <------ next(b,n) ----------------------------> # car = first element of list, cdr = rest of the list, drop_n = drop first n elements of list @@ -171,7 +172,7 @@ function term_matcher_constructor(term) frankestein = arguments(denominator)[1] ^ -(arguments(denominator)[2]) result = loop(frankestein, bindings, matchers) # if b is a number, like 3, we cant call loop with a^-3 bc it - # will automatically transform into 1/a^3. Therfore we need to + # will automatically transform into 1/a^3. Therefore we need to # create a matcher that flips the sign of the exponent. I created # this matecher just for `Slot`s and `DefSlot`s, but not for # terms or literals, because if b is a number and not a call, @@ -181,21 +182,27 @@ function term_matcher_constructor(term) result = loop(denominator, bindings, matchers_modified) end end - if result !== nothing - return success(result, 1) - end + result !== nothing && return success(result, 1) + return nothing end return term_matcher_pow # if the operation is commutative - elseif operation(term) in [+, *] + elseif acSets!==nothing && !isa(arguments(term)[1], Segment) && operation(term) in [+, *] function term_matcher_comm(success, data, bindings) !islist(data) && return nothing # if data is not a list, return nothing !iscall(car(data)) && return nothing # if first element is not a call, return nothing - for m in all_matchers - result = loop(car(data), bindings, m) - result !== nothing && return success(result, 1) + T = symtype(car(data)) + f = operation(car(data)) + data_args = arguments(car(data)) + + for inds in acSets(eachindex(data_args), length(arguments(term))) + candidate = Term{T}(f, @views data_args[inds]) + + result = loop(candidate, bindings, matchers) + result !== nothing && length(data_args) == length(inds) && return success(result,1) end + return nothing end return term_matcher_comm else @@ -204,9 +211,8 @@ function term_matcher_constructor(term) !iscall(car(data)) && return nothing # if first element is not a call, return nothing result = loop(car(data), bindings, matchers) - if result !== nothing - return success(result, 1) - end + result !== nothing && return success(result, 1) + return nothing end return term_matcher end @@ -215,35 +221,31 @@ end # creates a matcher for a term containing a defslot, such as: # (~x + ...complicated pattern...) * ~!y # normal part (can bee a tree) operation defslot part -function defslot_term_matcher_constructor(term) +function defslot_term_matcher_constructor(term, acSets) a = arguments(term) defslot_index = findfirst(x -> isa(x, DefSlot), a) # find the defslot in the term defslot = a[defslot_index] if length(a) == 2 - other_part_matcher = matcher(a[defslot_index == 1 ? 2 : 1]) + other_part_matcher = matcher(a[defslot_index == 1 ? 2 : 1]; acSets = acSets) else others = [a[i] for i in eachindex(a) if i != defslot_index] T = symtype(term) f = operation(term) - other_part_matcher = term_matcher_constructor(Term{T}(f, others)) + other_part_matcher = term_matcher_constructor(Term{T}(f, others), acSets) end - normal_matcher = term_matcher_constructor(term) + normal_matcher = term_matcher_constructor(term, acSets) function defslot_term_matcher(success, data, bindings) !islist(data) && return nothing # if data is not a list, return nothing result = normal_matcher(success, data, bindings) result !== nothing && return result - # if no match, try to match with a defslot - # if data (is not a tree and is just a symbol) or (is a tree not starting with the default operation) - - # checks wether it matches the normal part if yes executes (foo) + # if no match, try to match with a defslot. + # checks whether it matches the normal part if yes executes (foo) # (foo): adds the pair (default value name, default value) to the found bindings # <------------------(foo)----------------------------> result = other_part_matcher((b,n) -> assoc(b, defslot.name, defslot.defaultValue), data, bindings) - println(result) result !== nothing && return success(result, 1) - nothing end end diff --git a/src/rule.jl b/src/rule.jl index fed5895b6..d645d6036 100644 --- a/src/rule.jl +++ b/src/rule.jl @@ -437,16 +437,46 @@ Rule(acr::ACRule) = acr.rule getdepth(r::ACRule) = getdepth(r.rule) macro acrule(expr) - arity = length(expr.args[2].args[2:end]) + @assert expr.head == :call && expr.args[1] == :(=>) + lhs = expr.args[2] + rhs = rewrite_rhs(expr.args[3]) + keys = Symbol[] + lhs_term = makepattern(lhs, keys) + unique!(keys) + + arity = length(lhs.args[2:end]) + quote - ACRule(permutations, $(esc(:(@rule($(expr))))), $arity) + $(__source__) + lhs_pattern = $(lhs_term) + rule = Rule($(QuoteNode(expr)), + lhs_pattern, + matcher(lhs_pattern; acSets = permutations), + __MATCHES__ -> $(makeconsequent(rhs)), + rule_depth($lhs_term)) + ACRule(permutations, rule, $arity) end end macro ordered_acrule(expr) - arity = length(expr.args[2].args[2:end]) + @assert expr.head == :call && expr.args[1] == :(=>) + lhs = expr.args[2] + rhs = rewrite_rhs(expr.args[3]) + keys = Symbol[] + lhs_term = makepattern(lhs, keys) + unique!(keys) + + arity = length(lhs.args[2:end]) + quote - ACRule(combinations, $(esc(:(@rule($(expr))))), $arity) + $(__source__) + lhs_pattern = $(lhs_term) + rule = Rule($(QuoteNode(expr)), + lhs_pattern, + matcher(lhs_pattern; acSets = combinations), + __MATCHES__ -> $(makeconsequent(rhs)), + rule_depth($lhs_term)) + ACRule(combinations, rule, $arity) end end @@ -454,15 +484,11 @@ Base.show(io::IO, acr::ACRule) = print(io, "ACRule(", acr.rule, ")") function (acr::ACRule)(term) r = Rule(acr) - if !iscall(term) + if !iscall(term) || operation(term) != operation(r.lhs) + # different operations -> try deflsot r(term) else - f = operation(term) - # Assume that the matcher was formed by closing over a term - if f != operation(r.lhs) # Maybe offer a fallback if m.term errors. - return nothing - end - + f = operation(term) T = symtype(term) args = arguments(term) diff --git a/test/rewrite.jl b/test/rewrite.jl index c8ed107cd..fd34ebf3d 100644 --- a/test/rewrite.jl +++ b/test/rewrite.jl @@ -48,17 +48,21 @@ end end @testset "Commutative + and *" begin - r1 = @rule sin(~x) + cos(~x) => ~x - @test r1(sin(a)+cos(a)) === a - @test r1(sin(x)+cos(x)) === x - r2 = @rule (~x+~y)*(~z+~w)^(~m) => (~x, ~y, ~z, ~w, ~m) - r3 = @rule (~z+~w)^(~m)*(~x+~y) => (~x, ~y, ~z, ~w, ~m) + r1 = @acrule exp(sin(~x) + cos(~x)) => ~x + @test r1(exp(sin(a)+cos(a))) === a + @test r1(exp(sin(x)+cos(x))) === x + r2 = @acrule (~x+~y)*(~z+~w)^(~m) => (~x, ~y, ~z, ~w, ~m) + r3 = @acrule (~z+~w)^(~m)*(~x+~y) => (~x, ~y, ~z, ~w, ~m) @test r2((a+b)*(x+c)^b) === (a, b, x, c, b) @test r3((a+b)*(x+c)^b) === (a, b, x, c, b) - rPredicate1 = @rule ~x::(x->isa(x,Number)) + ~y => (~x, ~y) - rPredicate2 = @rule ~y + ~x::(x->isa(x,Number)) => (~x, ~y) + rPredicate1 = @acrule ~x::(x->isa(x,Number)) + ~y => (~x, ~y) + rPredicate2 = @acrule ~y + ~x::(x->isa(x,Number)) => (~x, ~y) @test rPredicate1(2+x) === (2, x) @test rPredicate2(2+x) === (2, x) + r5 = @acrule (~y*(~z+~w))+~x => (~x, ~y, ~z, ~w) + r6 = @acrule ~x+((~z+~w)*~y) => (~x, ~y, ~z, ~w) + @test r5(c*(a+b)+d) === (d, c, a, b) + @test r6(c*(a+b)+d) === (d, c, a, b) end @testset "Slot matcher with default value" begin From a7d57e957845565128bbc1bc7c210fba124bf8f9 Mon Sep 17 00:00:00 2001 From: Bumblebee00 Date: Fri, 20 Jun 2025 10:33:41 +0200 Subject: [PATCH 09/26] negative exponent feature is done in a different way, more clean --- src/matchers.jl | 50 +++++-------------------------------------------- test/rewrite.jl | 2 +- 2 files changed, 6 insertions(+), 46 deletions(-) diff --git a/src/matchers.jl b/src/matchers.jl index bb22ea273..63787b83c 100644 --- a/src/matchers.jl +++ b/src/matchers.jl @@ -48,27 +48,6 @@ function matcher(defslot::DefSlot; acSets = nothing) matcher(Slot(defslot.name, defslot.predicate)) end -# function opposite_sign_matcher(val::Any) -# end - -function opposite_sign_matcher(slot::Slot) - function slot_matcher(next, data, bindings) - !islist(data) && return nothing - val = get(bindings, slot.name, nothing) - if val !== nothing - if isequal(val, car(data)) - return next(bindings, 1) - end - elseif slot.predicate(car(data)) - next(assoc(bindings, slot.name, -car(data)), 1) # this - is the only difference wrt matcher(slot::Slot) - end - end -end - -function opposite_sign_matcher(defslot::DefSlot) - opposite_sign_matcher(Slot(defslot.name, defslot.predicate)) -end - # returns n == offset, 0 if failed function trymatchexpr(data, value, n) if !islist(value) @@ -149,13 +128,6 @@ function term_matcher_constructor(term, acSets) # if the operation is a pow, we have to match also 1/(...)^(...) with negative exponent if operation(term)==^ - # the below 4 lines could stay in the function term_matcher_pow, but - # are here to speed up the rule matcher function - cond = isa(arguments(term)[2], Slot) || isa(arguments(term)[2], DefSlot) - if cond - matchers_modified = (matcher(operation(term)), matcher(arguments(term)[1]), opposite_sign_matcher(arguments(term)[2])) # is this ok to be here or should it be outside neg_pow_term_matcher? - end - function term_matcher_pow(success, data, bindings) !islist(data) && return nothing # if data is not a list, return nothing !iscall(car(data)) && return nothing # if first element is not a call, return nothing @@ -166,27 +138,15 @@ function term_matcher_constructor(term, acSets) # if data is of the alternative form 1/(...)^(...), it might match with negative exponent if (operation(car(data))==/) && arguments(car(data))[1]==1 && iscall(arguments(car(data))[2]) && (operation(arguments(car(data))[2])==^) denominator = arguments(car(data))[2] - # let's say data = a^b with a and b can be whatever - # if b is not a number then call the loop function with a^-b - if !isa(arguments(denominator)[2], Number) - frankestein = arguments(denominator)[1] ^ -(arguments(denominator)[2]) - result = loop(frankestein, bindings, matchers) - # if b is a number, like 3, we cant call loop with a^-3 bc it - # will automatically transform into 1/a^3. Therefore we need to - # create a matcher that flips the sign of the exponent. I created - # this matecher just for `Slot`s and `DefSlot`s, but not for - # terms or literals, because if b is a number and not a call, - # certainly doesn't match a term (I hope). - # Also not a literal because...? - elseif cond - result = loop(denominator, bindings, matchers_modified) - end + T = symtype(denominator) + frankestein = Term{T}(^, [arguments(denominator)[1], -arguments(denominator)[2]]) + result = loop(frankestein, bindings, matchers) + result !== nothing && return success(result, 1) end - result !== nothing && return success(result, 1) return nothing end return term_matcher_pow - # if the operation is commutative + # if we want to do commutative checks, i.e. call matcher with different order of the arguments elseif acSets!==nothing && !isa(arguments(term)[1], Segment) && operation(term) in [+, *] function term_matcher_comm(success, data, bindings) !islist(data) && return nothing # if data is not a list, return nothing diff --git a/test/rewrite.jl b/test/rewrite.jl index fd34ebf3d..66a51cb8e 100644 --- a/test/rewrite.jl +++ b/test/rewrite.jl @@ -2,7 +2,7 @@ using SymbolicUtils include("utils.jl") -@syms a b c x +@syms a b c d x @testset "Equality" begin @eqtest a == a From 50b5e501e590e1b9fbc4b32674dff5167a6706b4 Mon Sep 17 00:00:00 2001 From: Bumblebee00 Date: Fri, 20 Jun 2025 11:33:13 +0200 Subject: [PATCH 10/26] fixed failing ci tests --- src/matchers.jl | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/matchers.jl b/src/matchers.jl index 63787b83c..c24e12027 100644 --- a/src/matchers.jl +++ b/src/matchers.jl @@ -127,17 +127,18 @@ function term_matcher_constructor(term, acSets) end # if the operation is a pow, we have to match also 1/(...)^(...) with negative exponent - if operation(term)==^ - function term_matcher_pow(success, data, bindings) + if operation(term) === ^ + function pow_term_matcher(success, data, bindings) !islist(data) && return nothing # if data is not a list, return nothing - !iscall(car(data)) && return nothing # if first element is not a call, return nothing + data = car(data) # from (..., ) to ... + !iscall(data) && return nothing # if first element is not a call, return nothing - result = loop(car(data), bindings, matchers) + result = loop(data, bindings, matchers) result !== nothing && return success(result, 1) # if data is of the alternative form 1/(...)^(...), it might match with negative exponent - if (operation(car(data))==/) && arguments(car(data))[1]==1 && iscall(arguments(car(data))[2]) && (operation(arguments(car(data))[2])==^) - denominator = arguments(car(data))[2] + if (operation(data) === /) && isequal(arguments(data)[1], 1) && iscall(arguments(data)[2]) && (operation(arguments(data)[2]) === ^) + denominator = arguments(data)[2] T = symtype(denominator) frankestein = Term{T}(^, [arguments(denominator)[1], -arguments(denominator)[2]]) result = loop(frankestein, bindings, matchers) @@ -145,10 +146,10 @@ function term_matcher_constructor(term, acSets) end return nothing end - return term_matcher_pow + return pow_term_matcher # if we want to do commutative checks, i.e. call matcher with different order of the arguments elseif acSets!==nothing && !isa(arguments(term)[1], Segment) && operation(term) in [+, *] - function term_matcher_comm(success, data, bindings) + function commutative_term_matcher(success, data, bindings) !islist(data) && return nothing # if data is not a list, return nothing !iscall(car(data)) && return nothing # if first element is not a call, return nothing @@ -164,7 +165,7 @@ function term_matcher_constructor(term, acSets) end return nothing end - return term_matcher_comm + return commutative_term_matcher else function term_matcher(success, data, bindings) !islist(data) && return nothing # if data is not a list, return nothing From 3bd128263f6a5540d7e90c67ed20a63a67f9a546 Mon Sep 17 00:00:00 2001 From: Bumblebee00 Date: Fri, 20 Jun 2025 11:50:39 +0200 Subject: [PATCH 11/26] added tests with deflost in operation call with more than two arguments --- test/rewrite.jl | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/test/rewrite.jl b/test/rewrite.jl index 66a51cb8e..41f0cbdb1 100644 --- a/test/rewrite.jl +++ b/test/rewrite.jl @@ -98,6 +98,15 @@ end @test r_mix((a + b*c)^2) === (2, c) @test r_mix((a + b*c)) === (1, c) @test r_mix((a + b)) === (1, 1) + + r_more_than_two_arguments = @rule (~!a)*exp(~x)*sin(~x) => (~a, ~x) + @test r_more_than_two_arguments(sin(x)*exp(x)) === (1, x) + @test r_more_than_two_arguments(sin(x)*exp(x)*a) === (a, x) + + r_mixmix = @rule (~!a)*exp(~x)*sin(~!b + (~x)^2 + ~x) => (~a, ~b, ~x) + @test r_mixmix(exp(x)*sin(1+x+x^2)*2) === (2, 1, x) + @test r_mixmix(exp(x)*sin(x+x^2)*2) === (2, 0, x) + @test r_mixmix(exp(x)*sin(x+x^2)) === (1, 0, x) end @testset "1/power matches power with exponent of opposite sign" begin From 6825df3c7fa7e147fe6238df4bbb45efc01a14cf Mon Sep 17 00:00:00 2001 From: Bumblebee00 Date: Sat, 21 Jun 2025 13:44:14 +0200 Subject: [PATCH 12/26] now rationals can be used in rules --- src/rule.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/rule.jl b/src/rule.jl index d645d6036..bfcaf975a 100644 --- a/src/rule.jl +++ b/src/rule.jl @@ -130,6 +130,9 @@ function makepattern(expr, keys, parentCall=nothing) # matches ~x::predicate makeslot(expr.args[2], keys) end + elseif expr.args[1] === :(//) + # bc when the expression is not quoted, 3//2 is a Rational{Int64}, not a call + return esc(expr.args[2] // expr.args[3]) else # make a pattern for every argument of the expr. :(term($(map(x->makepattern(x, keys, operation(expr)), expr.args)...); type=Any)) From e6bce154390c49a198664cd245384b7b9ca7d248 Mon Sep 17 00:00:00 2001 From: Bumblebee00 Date: Sun, 22 Jun 2025 11:30:50 +0200 Subject: [PATCH 13/26] created smrule (sum multiplication rule) macro --- src/SymbolicUtils.jl | 2 +- src/rule.jl | 20 ++++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/src/SymbolicUtils.jl b/src/SymbolicUtils.jl index 60b3fdf71..e44e2faa4 100644 --- a/src/SymbolicUtils.jl +++ b/src/SymbolicUtils.jl @@ -54,7 +54,7 @@ export Rewriters # A library for composing together expr -> expr functions using Combinatorics: permutations, combinations -export @rule, @acrule, RuleSet +export @rule, @acrule, @smrule, RuleSet # Rule type and @rule macro include("rule.jl") diff --git a/src/rule.jl b/src/rule.jl index bfcaf975a..403c402ce 100644 --- a/src/rule.jl +++ b/src/rule.jl @@ -384,6 +384,26 @@ macro rule(expr) end end +macro smrule(expr) + @assert expr.head == :call && expr.args[1] == :(=>) + lhs = expr.args[2] + rhs = rewrite_rhs(expr.args[3]) + keys = Symbol[] + lhs_term = makepattern(lhs, keys) + unique!(keys) + quote + $(__source__) + lhs_pattern = $(lhs_term) + Rule( + $(QuoteNode(expr)), + lhs_pattern, + matcher(lhs_pattern; acSets = permutations), + __MATCHES__ -> $(makeconsequent(rhs)), + rule_depth($lhs_term) + ) + end +end + """ @capture ex pattern From f8c88413e9dc47a48aeb51684aa98eb31ffb1397 Mon Sep 17 00:00:00 2001 From: Bumblebee00 Date: Sun, 22 Jun 2025 14:21:37 +0200 Subject: [PATCH 14/26] enhance commutative term matcher to validate operation type --- src/matchers.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/matchers.jl b/src/matchers.jl index c24e12027..b8a784418 100644 --- a/src/matchers.jl +++ b/src/matchers.jl @@ -152,6 +152,7 @@ function term_matcher_constructor(term, acSets) function commutative_term_matcher(success, data, bindings) !islist(data) && return nothing # if data is not a list, return nothing !iscall(car(data)) && return nothing # if first element is not a call, return nothing + operation(term) !== operation(car(data)) && return nothing # if the operation of data is not the correct one, don't even try T = symtype(car(data)) f = operation(car(data)) From e742a843d05570c2b0dbf4466ae0df366b8eb510 Mon Sep 17 00:00:00 2001 From: Bumblebee00 Date: Sun, 22 Jun 2025 16:07:56 +0200 Subject: [PATCH 15/26] fixed bug in defslot code and improved performance --- src/matchers.jl | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/matchers.jl b/src/matchers.jl index b8a784418..4bd1d01b4 100644 --- a/src/matchers.jl +++ b/src/matchers.jl @@ -200,12 +200,14 @@ function defslot_term_matcher_constructor(term, acSets) function defslot_term_matcher(success, data, bindings) !islist(data) && return nothing # if data is not a list, return nothing - result = normal_matcher(success, data, bindings) - result !== nothing && return result + # call the normal mathcer, with succes function foo1 that simply returns the bindings + # <--foo1--> + result = normal_matcher((b,n) -> b, data, bindings) + result !== nothing && return success(result, 1) # if no match, try to match with a defslot. - # checks whether it matches the normal part if yes executes (foo) - # (foo): adds the pair (default value name, default value) to the found bindings - # <------------------(foo)----------------------------> + # checks whether it matches the normal part if yes executes foo2 + # foo2: adds the pair (default value name, default value) to the found bindings + # <-------------------foo2----------------------------> result = other_part_matcher((b,n) -> assoc(b, defslot.name, defslot.defaultValue), data, bindings) result !== nothing && return success(result, 1) nothing From 9e4596d11521f0dd854804ea228628a6547b6b1b Mon Sep 17 00:00:00 2001 From: Bumblebee00 Date: Sun, 22 Jun 2025 16:18:34 +0200 Subject: [PATCH 16/26] improved negative exponent pattern matching. now it matches also for expressions like (1/...)^(...) --- src/matchers.jl | 10 +++++++++- test/rewrite.jl | 5 ++++- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/src/matchers.jl b/src/matchers.jl index 4bd1d01b4..6541f5166 100644 --- a/src/matchers.jl +++ b/src/matchers.jl @@ -136,6 +136,14 @@ function term_matcher_constructor(term, acSets) result = loop(data, bindings, matchers) result !== nothing && return success(result, 1) + # if data is of the alternative form (1/...)^(...), it might match with negative exponent + if (operation(data) === ^) && iscall(arguments(data)[1]) && (operation(arguments(data)[1]) === /) && isequal(arguments(arguments(data)[1])[1], 1) + one_over_smth = arguments(data)[1] + T = symtype(one_over_smth) + frankestein = Term{T}(^, [arguments(one_over_smth)[2], -arguments(data)[2]]) + result = loop(frankestein, bindings, matchers) + result !== nothing && return success(result, 1) + end # if data is of the alternative form 1/(...)^(...), it might match with negative exponent if (operation(data) === /) && isequal(arguments(data)[1], 1) && iscall(arguments(data)[2]) && (operation(arguments(data)[2]) === ^) denominator = arguments(data)[2] @@ -200,7 +208,7 @@ function defslot_term_matcher_constructor(term, acSets) function defslot_term_matcher(success, data, bindings) !islist(data) && return nothing # if data is not a list, return nothing - # call the normal mathcer, with succes function foo1 that simply returns the bindings + # call the normal matcher, with success function foo1 that simply returns the bindings # <--foo1--> result = normal_matcher((b,n) -> b, data, bindings) result !== nothing && return success(result, 1) diff --git a/test/rewrite.jl b/test/rewrite.jl index 41f0cbdb1..3beb9c3fa 100644 --- a/test/rewrite.jl +++ b/test/rewrite.jl @@ -109,7 +109,7 @@ end @test r_mixmix(exp(x)*sin(x+x^2)) === (1, 0, x) end -@testset "1/power matches power with exponent of opposite sign" begin +@testset "power matcher with negative exponent" begin r1 = @rule (~x)^(~y) => (~x, ~y) # rule with slot as exponent @test r1(1/a^b) === (a, -b) # uses frankestein @test r1(1/a^(b+2c)) === (a, -b-2c) # uses frankestein @@ -124,6 +124,9 @@ end @test r1defslot(1/a^(b+2c)) === (a, -b-2c) # uses frankestein @test r1defslot(1/a^2) === (a, -2) # uses opposite_sign_matcher @test r1defslot(a) === (a, 1) + + r = @rule (~x + ~y)^(~m) => (~x, ~y, ~m) # rule to match (1/...)^(...) + @test r((1/(a+b))^3) === (a,b,-3) end using SymbolicUtils: @capture From bdce8c4a02a6eeb5efd4d2ca63d4deab19cf3785 Mon Sep 17 00:00:00 2001 From: Bumblebee00 Date: Tue, 24 Jun 2025 17:56:36 +0200 Subject: [PATCH 17/26] changed order of checks in pow term matcher --- src/matchers.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/matchers.jl b/src/matchers.jl index 6541f5166..d0611de7b 100644 --- a/src/matchers.jl +++ b/src/matchers.jl @@ -133,9 +133,6 @@ function term_matcher_constructor(term, acSets) data = car(data) # from (..., ) to ... !iscall(data) && return nothing # if first element is not a call, return nothing - result = loop(data, bindings, matchers) - result !== nothing && return success(result, 1) - # if data is of the alternative form (1/...)^(...), it might match with negative exponent if (operation(data) === ^) && iscall(arguments(data)[1]) && (operation(arguments(data)[1]) === /) && isequal(arguments(arguments(data)[1])[1], 1) one_over_smth = arguments(data)[1] @@ -144,6 +141,10 @@ function term_matcher_constructor(term, acSets) result = loop(frankestein, bindings, matchers) result !== nothing && return success(result, 1) end + + result = loop(data, bindings, matchers) + result !== nothing && return success(result, 1) + # if data is of the alternative form 1/(...)^(...), it might match with negative exponent if (operation(data) === /) && isequal(arguments(data)[1], 1) && iscall(arguments(data)[2]) && (operation(arguments(data)[2]) === ^) denominator = arguments(data)[2] From 8c8a207da975f342bceee4e9d917afad974b7f41 Mon Sep 17 00:00:00 2001 From: Bumblebee00 Date: Fri, 27 Jun 2025 08:44:00 +0200 Subject: [PATCH 18/26] added match for exp and sqrt calls --- src/matchers.jl | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/matchers.jl b/src/matchers.jl index d0611de7b..009ca981a 100644 --- a/src/matchers.jl +++ b/src/matchers.jl @@ -153,6 +153,23 @@ function term_matcher_constructor(term, acSets) result = loop(frankestein, bindings, matchers) result !== nothing && return success(result, 1) end + + # if data is a exp call, it might match with base e + if operation(data)===exp + T = symtype(arguments(data)[1]) + frankestein = Term{T}(^,[ℯ,arguments(data)[1]]) + result = loop(frankestein, bindings, matchers) + result !== nothing && return success(result, 1) + end + + # if data is a sqrt call, it might match with exponent 1//2 + if operation(data)===sqrt + T = symtype(arguments(data)[1]) + frankestein = Term{T}(^,[arguments(data)[1], 1//2]) + result = loop(frankestein, bindings, matchers) + result !== nothing && return success(result, 1) + end + return nothing end return pow_term_matcher From 08e99933cbe5db4c53e0b5fff4d081f9625e841c Mon Sep 17 00:00:00 2001 From: Bumblebee00 Date: Mon, 30 Jun 2025 21:14:13 +0200 Subject: [PATCH 19/26] removed smrule macro and added commutativity checks to the rule macro --- src/SymbolicUtils.jl | 2 +- src/matchers.jl | 37 ++++++++++++++++++++++--------------- src/rule.jl | 26 ++++---------------------- test/rewrite.jl | 15 ++++++++------- 4 files changed, 35 insertions(+), 45 deletions(-) diff --git a/src/SymbolicUtils.jl b/src/SymbolicUtils.jl index e44e2faa4..60b3fdf71 100644 --- a/src/SymbolicUtils.jl +++ b/src/SymbolicUtils.jl @@ -54,7 +54,7 @@ export Rewriters # A library for composing together expr -> expr functions using Combinatorics: permutations, combinations -export @rule, @acrule, @smrule, RuleSet +export @rule, @acrule, RuleSet # Rule type and @rule macro include("rule.jl") diff --git a/src/matchers.jl b/src/matchers.jl index 009ca981a..6973cc9b7 100644 --- a/src/matchers.jl +++ b/src/matchers.jl @@ -6,7 +6,7 @@ # 3. Callback: takes arguments Dictionary × Number of elements matched # -function matcher(val::Any; acSets = nothing) +function matcher(val::Any, acSets) # if val is a call (like an operation) creates a term matcher or term matcher with defslot if iscall(val) # if has two arguments and one of them is a DefSlot, create a term matcher with defslot @@ -25,7 +25,7 @@ function matcher(val::Any; acSets = nothing) end # acSets is not used but needs to be there in case matcher(::Slot) is directly called from the macro -function matcher(slot::Slot; acSets = nothing) +function matcher(slot::Slot, acSets) function slot_matcher(next, data, bindings) !islist(data) && return nothing val = get(bindings, slot.name, nothing) @@ -44,8 +44,8 @@ end # this is called only when defslot_term_matcher finds the operation and tries # to match it, so no default value used. So the same function as slot_matcher # can be used -function matcher(defslot::DefSlot; acSets = nothing) - matcher(Slot(defslot.name, defslot.predicate)) +function matcher(defslot::DefSlot, acSets) + matcher(Slot(defslot.name, defslot.predicate), nothing) # slot matcher doesnt use acsets end # returns n == offset, 0 if failed @@ -76,7 +76,7 @@ function trymatchexpr(data, value, n) end end -function matcher(segment::Segment; acSets=nothing) +function matcher(segment::Segment, acSets) function segment_matcher(success, data, bindings) val = get(bindings, segment.name, nothing) @@ -105,7 +105,7 @@ function matcher(segment::Segment; acSets=nothing) end function term_matcher_constructor(term, acSets) - matchers = (matcher(operation(term); acSets=acSets), map(x->matcher(x;acSets=acSets), arguments(term))...,) + matchers = (matcher(operation(term), acSets), map(x->matcher(x,acSets), arguments(term))...,) function loop(term, bindings′, matchers′) # Get it to compile faster if !islist(matchers′) @@ -181,14 +181,21 @@ function term_matcher_constructor(term, acSets) operation(term) !== operation(car(data)) && return nothing # if the operation of data is not the correct one, don't even try T = symtype(car(data)) - f = operation(car(data)) - data_args = arguments(car(data)) - - for inds in acSets(eachindex(data_args), length(arguments(term))) - candidate = Term{T}(f, @views data_args[inds]) - - result = loop(candidate, bindings, matchers) - result !== nothing && length(data_args) == length(inds) && return success(result,1) + if T <: Number + f = operation(car(data)) + data_args = arguments(car(data)) + + for inds in acSets(eachindex(data_args), length(arguments(term))) + candidate = Term{T}(f, @views data_args[inds]) + + result = loop(candidate, bindings, matchers) + result !== nothing && length(data_args) == length(inds) && return success(result,1) + end + # if car(data) does not subtype to number, it might not be commutative + else + # call the normal matcher + result = loop(car(data), bindings, matchers) + result !== nothing && return success(result, 1) end return nothing end @@ -214,7 +221,7 @@ function defslot_term_matcher_constructor(term, acSets) defslot_index = findfirst(x -> isa(x, DefSlot), a) # find the defslot in the term defslot = a[defslot_index] if length(a) == 2 - other_part_matcher = matcher(a[defslot_index == 1 ? 2 : 1]; acSets = acSets) + other_part_matcher = matcher(a[defslot_index == 1 ? 2 : 1], acSets) else others = [a[i] for i in eachindex(a) if i != defslot_index] T = symtype(term) diff --git a/src/rule.jl b/src/rule.jl index 403c402ce..556400356 100644 --- a/src/rule.jl +++ b/src/rule.jl @@ -367,24 +367,6 @@ _In the consequent pattern_: Use `(@ctx)` to access the context object on the ri of an expression. """ macro rule(expr) - @assert expr.head == :call && expr.args[1] == :(=>) - lhs = expr.args[2] - rhs = rewrite_rhs(expr.args[3]) - keys = Symbol[] - lhs_term = makepattern(lhs, keys) - unique!(keys) - quote - $(__source__) - lhs_pattern = $(lhs_term) - Rule($(QuoteNode(expr)), - lhs_pattern, - matcher(lhs_pattern), - __MATCHES__ -> $(makeconsequent(rhs)), - rule_depth($lhs_term)) - end -end - -macro smrule(expr) @assert expr.head == :call && expr.args[1] == :(=>) lhs = expr.args[2] rhs = rewrite_rhs(expr.args[3]) @@ -397,7 +379,7 @@ macro smrule(expr) Rule( $(QuoteNode(expr)), lhs_pattern, - matcher(lhs_pattern; acSets = permutations), + matcher(lhs_pattern, permutations), __MATCHES__ -> $(makeconsequent(rhs)), rule_depth($lhs_term) ) @@ -435,7 +417,7 @@ macro capture(ex, lhs) lhs_pattern = $(lhs_term) __MATCHES__ = Rule($(QuoteNode(lhs)), lhs_pattern, - matcher(lhs_pattern), + matcher(lhs_pattern, nothing), identity, rule_depth($lhs_term))($(esc(ex))) if __MATCHES__ !== nothing @@ -474,7 +456,7 @@ macro acrule(expr) lhs_pattern = $(lhs_term) rule = Rule($(QuoteNode(expr)), lhs_pattern, - matcher(lhs_pattern; acSets = permutations), + matcher(lhs_pattern, permutations), __MATCHES__ -> $(makeconsequent(rhs)), rule_depth($lhs_term)) ACRule(permutations, rule, $arity) @@ -496,7 +478,7 @@ macro ordered_acrule(expr) lhs_pattern = $(lhs_term) rule = Rule($(QuoteNode(expr)), lhs_pattern, - matcher(lhs_pattern; acSets = combinations), + matcher(lhs_pattern, combinations), __MATCHES__ -> $(makeconsequent(rhs)), rule_depth($lhs_term)) ACRule(combinations, rule, $arity) diff --git a/test/rewrite.jl b/test/rewrite.jl index 3beb9c3fa..f7d4c5f85 100644 --- a/test/rewrite.jl +++ b/test/rewrite.jl @@ -48,19 +48,20 @@ end end @testset "Commutative + and *" begin - r1 = @acrule exp(sin(~x) + cos(~x)) => ~x + r1 = @rule exp(sin(~x) + cos(~x)) => ~x + # using a or x changes the order of the arguments in the call @test r1(exp(sin(a)+cos(a))) === a @test r1(exp(sin(x)+cos(x))) === x - r2 = @acrule (~x+~y)*(~z+~w)^(~m) => (~x, ~y, ~z, ~w, ~m) - r3 = @acrule (~z+~w)^(~m)*(~x+~y) => (~x, ~y, ~z, ~w, ~m) + r2 = @rule (~x+~y)*(~z+~w)^(~m) => (~x, ~y, ~z, ~w, ~m) + r3 = @rule (~z+~w)^(~m)*(~x+~y) => (~x, ~y, ~z, ~w, ~m) @test r2((a+b)*(x+c)^b) === (a, b, x, c, b) @test r3((a+b)*(x+c)^b) === (a, b, x, c, b) - rPredicate1 = @acrule ~x::(x->isa(x,Number)) + ~y => (~x, ~y) - rPredicate2 = @acrule ~y + ~x::(x->isa(x,Number)) => (~x, ~y) + rPredicate1 = @rule ~x::(x->isa(x,Number)) + ~y => (~x, ~y) + rPredicate2 = @rule ~y + ~x::(x->isa(x,Number)) => (~x, ~y) @test rPredicate1(2+x) === (2, x) @test rPredicate2(2+x) === (2, x) - r5 = @acrule (~y*(~z+~w))+~x => (~x, ~y, ~z, ~w) - r6 = @acrule ~x+((~z+~w)*~y) => (~x, ~y, ~z, ~w) + r5 = @rule (~y*(~z+~w))+~x => (~x, ~y, ~z, ~w) + r6 = @rule ~x+((~z+~w)*~y) => (~x, ~y, ~z, ~w) @test r5(c*(a+b)+d) === (d, c, a, b) @test r6(c*(a+b)+d) === (d, c, a, b) end From 80cabb18862f3b323171498011f55938289e464b Mon Sep 17 00:00:00 2001 From: Bumblebee00 Date: Mon, 7 Jul 2025 15:30:47 +0200 Subject: [PATCH 20/26] added commutativity checks also for segment matcher --- src/matchers.jl | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/src/matchers.jl b/src/matchers.jl index 6973cc9b7..22cfc66e8 100644 --- a/src/matchers.jl +++ b/src/matchers.jl @@ -91,12 +91,9 @@ function matcher(segment::Segment, acSets) for i=length(data):-1:0 subexpr = take_n(data, i) - if segment.predicate(subexpr) - res = success(assoc(bindings, segment.name, subexpr), i) - if res !== nothing - break - end - end + !segment.predicate(subexpr) && continue + res = success(assoc(bindings, segment.name, subexpr), i) + res !== nothing && break end return res @@ -174,7 +171,7 @@ function term_matcher_constructor(term, acSets) end return pow_term_matcher # if we want to do commutative checks, i.e. call matcher with different order of the arguments - elseif acSets!==nothing && !isa(arguments(term)[1], Segment) && operation(term) in [+, *] + elseif acSets!==nothing && operation(term) in [+, *] function commutative_term_matcher(success, data, bindings) !islist(data) && return nothing # if data is not a list, return nothing !iscall(car(data)) && return nothing # if first element is not a call, return nothing @@ -185,11 +182,11 @@ function term_matcher_constructor(term, acSets) f = operation(car(data)) data_args = arguments(car(data)) - for inds in acSets(eachindex(data_args), length(arguments(term))) + for inds in acSets(eachindex(data_args), length(data_args)) candidate = Term{T}(f, @views data_args[inds]) result = loop(candidate, bindings, matchers) - result !== nothing && length(data_args) == length(inds) && return success(result,1) + result !== nothing && return success(result,1) end # if car(data) does not subtype to number, it might not be commutative else From 2dbff775da9f2feda8d0cdfd32899a0590e5c489 Mon Sep 17 00:00:00 2001 From: Bumblebee00 Date: Mon, 7 Jul 2025 17:18:47 +0200 Subject: [PATCH 21/26] fixed predicates with defslots --- src/matchers.jl | 2 +- src/rule.jl | 2 +- test/rewrite.jl | 6 ++++++ 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/matchers.jl b/src/matchers.jl index 22cfc66e8..5b0ad5a16 100644 --- a/src/matchers.jl +++ b/src/matchers.jl @@ -184,7 +184,7 @@ function term_matcher_constructor(term, acSets) for inds in acSets(eachindex(data_args), length(data_args)) candidate = Term{T}(f, @views data_args[inds]) - + result = loop(candidate, bindings, matchers) result !== nothing && return success(result,1) end diff --git a/src/rule.jl b/src/rule.jl index 556400356..2c2689ea8 100644 --- a/src/rule.jl +++ b/src/rule.jl @@ -82,7 +82,7 @@ function makeDefSlot(s::Expr, keys, op) push!(keys, name) tmp = defaultValOfCall(op) - :(DefSlot($(QuoteNode(name)), $(esc(s.args[2])), $(esc(op))), $(esc(tmp))) + :(DefSlot($(QuoteNode(name)), $(esc(s.args[2])), $(esc(op)), $(esc(tmp)))) end diff --git a/test/rewrite.jl b/test/rewrite.jl index f7d4c5f85..ebd084e76 100644 --- a/test/rewrite.jl +++ b/test/rewrite.jl @@ -108,6 +108,12 @@ end @test r_mixmix(exp(x)*sin(1+x+x^2)*2) === (2, 1, x) @test r_mixmix(exp(x)*sin(x+x^2)*2) === (2, 0, x) @test r_mixmix(exp(x)*sin(x+x^2)) === (1, 0, x) + + r_predicate = @rule ~x + (~!m::(var->isa(var, Int))) => (~x, ~m) + @test r_predicate(x+2) === (x, 2) + @test r_predicate(x+2.0) !== (x, 2.0) + # Note: r_predicate(x+2.0) doesnt return nothing, but (x+2.0, 0) + # becasue of the defslot end @testset "power matcher with negative exponent" begin From 734d1b9b772bac51d8cb19d4c051818a7f3f73b5 Mon Sep 17 00:00:00 2001 From: Bumblebee00 Date: Sun, 3 Aug 2025 23:19:01 +0200 Subject: [PATCH 22/26] now the pattern ~x^~m matches 1/x with m=-1 --- src/matchers.jl | 9 +++++++++ test/rewrite.jl | 1 + 2 files changed, 10 insertions(+) diff --git a/src/matchers.jl b/src/matchers.jl index 5b0ad5a16..fea162494 100644 --- a/src/matchers.jl +++ b/src/matchers.jl @@ -150,6 +150,15 @@ function term_matcher_constructor(term, acSets) result = loop(frankestein, bindings, matchers) result !== nothing && return success(result, 1) end + + # if data is of the alternative form 1/(...), it might match with exponent = -1 + if (operation(data) === /) && isequal(arguments(data)[1], 1) + denominator = arguments(data)[2] + T = symtype(denominator) + frankestein = Term{T}(^, [denominator, -1]) + result = loop(frankestein, bindings, matchers) + result !== nothing && return success(result, 1) + end # if data is a exp call, it might match with base e if operation(data)===exp diff --git a/test/rewrite.jl b/test/rewrite.jl index ebd084e76..b99dcb1f8 100644 --- a/test/rewrite.jl +++ b/test/rewrite.jl @@ -121,6 +121,7 @@ end @test r1(1/a^b) === (a, -b) # uses frankestein @test r1(1/a^(b+2c)) === (a, -b-2c) # uses frankestein @test r1(1/a^2) === (a, -2) # uses opposite_sign_matcher + @test r1(1/a) === (a, -1) r2 = @rule (~x)^(~y + ~z) => (~x, ~y, ~z) # rule with term as exponent @test r2(1/a^(b+2c)) === (a, -b, -2c) # uses frankestein From 4a49b1906e4cf262a343dde6921396f9972dbc10 Mon Sep 17 00:00:00 2001 From: Bumblebee00 Date: Wed, 6 Aug 2025 10:59:05 +0200 Subject: [PATCH 23/26] added tests for power match with sqrt and exp functions --- test/rewrite.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/rewrite.jl b/test/rewrite.jl index b99dcb1f8..8de238a4f 100644 --- a/test/rewrite.jl +++ b/test/rewrite.jl @@ -137,6 +137,12 @@ end @test r((1/(a+b))^3) === (a,b,-3) end +@testset "special power matches" begin + r1 = @rule (~x)^(~y) => (~x, ~y) + @test r1(exp(a)) === (ℯ, a) # uses exp_matcher + @test r1(sqrt(a)) === (a, 1//2) # uses sqrt_matcher +end + using SymbolicUtils: @capture @testset "Capture form" begin From 05a5af2fa05dafaa207b03de2aae2332e36e688d Mon Sep 17 00:00:00 2001 From: Bumblebee00 Date: Wed, 6 Aug 2025 11:35:03 +0200 Subject: [PATCH 24/26] refactor --- src/matchers.jl | 46 +++++++++++++++++----------------------------- 1 file changed, 17 insertions(+), 29 deletions(-) diff --git a/src/matchers.jl b/src/matchers.jl index fea162494..a0725af83 100644 --- a/src/matchers.jl +++ b/src/matchers.jl @@ -130,48 +130,36 @@ function term_matcher_constructor(term, acSets) data = car(data) # from (..., ) to ... !iscall(data) && return nothing # if first element is not a call, return nothing - # if data is of the alternative form (1/...)^(...), it might match with negative exponent + result = loop(data, bindings, matchers) + result !== nothing && return success(result, 1) + + frankestein = nothing if (operation(data) === ^) && iscall(arguments(data)[1]) && (operation(arguments(data)[1]) === /) && isequal(arguments(arguments(data)[1])[1], 1) + # if data is of the alternative form (1/...)^(...) one_over_smth = arguments(data)[1] T = symtype(one_over_smth) frankestein = Term{T}(^, [arguments(one_over_smth)[2], -arguments(data)[2]]) - result = loop(frankestein, bindings, matchers) - result !== nothing && return success(result, 1) - end - - result = loop(data, bindings, matchers) - result !== nothing && return success(result, 1) - - # if data is of the alternative form 1/(...)^(...), it might match with negative exponent - if (operation(data) === /) && isequal(arguments(data)[1], 1) && iscall(arguments(data)[2]) && (operation(arguments(data)[2]) === ^) + elseif (operation(data) === /) && isequal(arguments(data)[1], 1) && iscall(arguments(data)[2]) && (operation(arguments(data)[2]) === ^) + # if data is of the alternative form 1/(...)^(...) denominator = arguments(data)[2] T = symtype(denominator) frankestein = Term{T}(^, [arguments(denominator)[1], -arguments(denominator)[2]]) - result = loop(frankestein, bindings, matchers) - result !== nothing && return success(result, 1) - end - - # if data is of the alternative form 1/(...), it might match with exponent = -1 - if (operation(data) === /) && isequal(arguments(data)[1], 1) + elseif (operation(data) === /) && isequal(arguments(data)[1], 1) + # if data is of the alternative form 1/(...), it might match with exponent = -1 denominator = arguments(data)[2] T = symtype(denominator) frankestein = Term{T}(^, [denominator, -1]) - result = loop(frankestein, bindings, matchers) - result !== nothing && return success(result, 1) - end - - # if data is a exp call, it might match with base e - if operation(data)===exp + elseif operation(data)===exp + # if data is a exp call, it might match with base e T = symtype(arguments(data)[1]) - frankestein = Term{T}(^,[ℯ,arguments(data)[1]]) - result = loop(frankestein, bindings, matchers) - result !== nothing && return success(result, 1) - end - - # if data is a sqrt call, it might match with exponent 1//2 - if operation(data)===sqrt + frankestein = Term{T}(^,[ℯ, arguments(data)[1]]) + elseif operation(data)===sqrt + # if data is a sqrt call, it might match with exponent 1//2 T = symtype(arguments(data)[1]) frankestein = Term{T}(^,[arguments(data)[1], 1//2]) + end + + if frankestein !==nothing result = loop(frankestein, bindings, matchers) result !== nothing && return success(result, 1) end From 36034e0a68eecf73b5f00bb937b0149256d4e90a Mon Sep 17 00:00:00 2001 From: Bumblebee00 Date: Fri, 8 Aug 2025 14:26:28 +0200 Subject: [PATCH 25/26] =?UTF-8?q?now=20...^(1//2)=20matches=20in=20the=20r?= =?UTF-8?q?ule=20with=20sqrt,=20and=20=E2=84=AF^...=20matches=20in=20the?= =?UTF-8?q?=20rule=20with=20exp?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/matchers.jl | 40 ++++++++++++++++++++++++++++++++++++++++ test/rewrite.jl | 12 +++++++++++- 2 files changed, 51 insertions(+), 1 deletion(-) diff --git a/src/matchers.jl b/src/matchers.jl index a0725af83..ec55de3bb 100644 --- a/src/matchers.jl +++ b/src/matchers.jl @@ -194,6 +194,46 @@ function term_matcher_constructor(term, acSets) return nothing end return commutative_term_matcher + # if the operation is sqrt, we have to match also ^(1//2) + elseif operation(term)==sqrt + function sqrt_matcher(success, data, bindings) + !islist(data) && return nothing # if data is not a list, return nothing + data = car(data) + !iscall(data) && return nothing # if first element is not a call, return nothing + + # do the normal matcher + result = loop(data, bindings, matchers) + result !== nothing && return success(result, 1) + + if (operation(data) === ^) && (arguments(data)[2] === 1//2) + T = symtype(arguments(data)[1]) + frankestein = Term{T}(sqrt,[arguments(data)[1]]) + result = loop(frankestein, bindings, matchers) + result !== nothing && return success(result, 1) + end + return nothing + end + return sqrt_matcher + # if the operation is exp, we have to match also ℯ^ + elseif operation(term)==exp + function exp_matcher(success, data, bindings) + !islist(data) && return nothing # if data is not a list, return nothing + data = car(data) + !iscall(data) && return nothing # if first element is not a call, return nothing + + # do the normal matcher + result = loop(data, bindings, matchers) + result !== nothing && return success(result, 1) + + if (operation(data) === ^) && (arguments(data)[1] === ℯ) + T = symtype(arguments(data)[2]) + frankestein = Term{T}(exp,[arguments(data)[2]]) + result = loop(frankestein, bindings, matchers) + result !== nothing && return success(result, 1) + end + return nothing + end + return exp_matcher else function term_matcher(success, data, bindings) !islist(data) && return nothing # if data is not a list, return nothing diff --git a/test/rewrite.jl b/test/rewrite.jl index 8de238a4f..0d228af8f 100644 --- a/test/rewrite.jl +++ b/test/rewrite.jl @@ -113,7 +113,7 @@ end @test r_predicate(x+2) === (x, 2) @test r_predicate(x+2.0) !== (x, 2.0) # Note: r_predicate(x+2.0) doesnt return nothing, but (x+2.0, 0) - # becasue of the defslot + # because of the defslot end @testset "power matcher with negative exponent" begin @@ -143,6 +143,16 @@ end @test r1(sqrt(a)) === (a, 1//2) # uses sqrt_matcher end +@testset "Alternate form of special functions" begin + rsqrt = @rule sqrt(~x) => ~x + @test rsqrt(sqrt(x))===x + @test rsqrt((x)^(1//2))===x + + rexp = @rule exp(~x) => ~x + @test rexp(exp(x)) === x + @test rexp(ℯ^x) === x +end + using SymbolicUtils: @capture @testset "Capture form" begin From 9142ba0c8cf9873867bf3bac24d8b7ae90c9a95b Mon Sep 17 00:00:00 2001 From: Bumblebee00 Date: Fri, 8 Aug 2025 16:25:26 +0200 Subject: [PATCH 26/26] first prototype --- src/matchers.jl | 119 +++++++++++++++++++++++------------------------- src/rule.jl | 24 +++++----- test/rewrite.jl | 19 ++++---- 3 files changed, 82 insertions(+), 80 deletions(-) diff --git a/src/matchers.jl b/src/matchers.jl index ec55de3bb..f61ae5ab1 100644 --- a/src/matchers.jl +++ b/src/matchers.jl @@ -6,16 +6,16 @@ # 3. Callback: takes arguments Dictionary × Number of elements matched # -function matcher(val::Any, acSets) +function matcher(val::Any, acSets, condition) # if val is a call (like an operation) creates a term matcher or term matcher with defslot if iscall(val) # if has two arguments and one of them is a DefSlot, create a term matcher with defslot # just two arguments bc defslot is only supported with operations with two args: *, ^, + if any(x -> isa(x, DefSlot), arguments(val)) - return defslot_term_matcher_constructor(val, acSets) + return defslot_term_matcher_constructor(val, acSets, condition) end # else return a normal term matcher - return term_matcher_constructor(val, acSets) + return term_matcher_constructor(val, acSets, condition) end function literal_matcher(next, data, bindings) @@ -24,8 +24,9 @@ function matcher(val::Any, acSets) end end -# acSets is not used but needs to be there in case matcher(::Slot) is directly called from the macro -function matcher(slot::Slot, acSets) +# acSets and condition are not used but needs to be there in case +# matcher(::Slot) is directly called from the macro +function matcher(slot::Slot, acSets, condition) function slot_matcher(next, data, bindings) !islist(data) && return nothing val = get(bindings, slot.name, nothing) @@ -36,6 +37,7 @@ function matcher(slot::Slot, acSets) end # elseif the first element of data matches the slot predicate, add it to bindings and call next elseif slot.predicate(car(data)) + # println("slot of $slot matched") next(assoc(bindings, slot.name, car(data)), 1) end end @@ -44,8 +46,8 @@ end # this is called only when defslot_term_matcher finds the operation and tries # to match it, so no default value used. So the same function as slot_matcher # can be used -function matcher(defslot::DefSlot, acSets) - matcher(Slot(defslot.name, defslot.predicate), nothing) # slot matcher doesnt use acsets +function matcher(defslot::DefSlot, acSets, condition) + matcher(Slot(defslot.name, defslot.predicate), nothing, nothing) end # returns n == offset, 0 if failed @@ -101,8 +103,11 @@ function matcher(segment::Segment, acSets) end end -function term_matcher_constructor(term, acSets) - matchers = (matcher(operation(term), acSets), map(x->matcher(x,acSets), arguments(term))...,) +function term_matcher_constructor(term, acSets, condition) + matchers = ( + matcher(operation(term), acSets, condition), + map(x->matcher(x,acSets, condition), arguments(term))..., + ) function loop(term, bindings′, matchers′) # Get it to compile faster if !islist(matchers′) @@ -123,15 +128,32 @@ function term_matcher_constructor(term, acSets) # the length of the list, is considered empty end + # if condition errors, this means not all the bindings + # are associated, so we are not at the end of the match. So + # we continue to the next matchers + function check_conditions(result) + result === nothing && return false + try + tmp = condition(result) + # tmp==nothing means no conditions are present + tmp===nothing && return true + return tmp + catch e + # println("condition failed, continuing") + return true + end + end + # if the operation is a pow, we have to match also 1/(...)^(...) with negative exponent if operation(term) === ^ function pow_term_matcher(success, data, bindings) + # println("in ^ matcher of $term with data $data") !islist(data) && return nothing # if data is not a list, return nothing data = car(data) # from (..., ) to ... !iscall(data) && return nothing # if first element is not a call, return nothing result = loop(data, bindings, matchers) - result !== nothing && return success(result, 1) + check_conditions(result) && return success(result, 1) frankestein = nothing if (operation(data) === ^) && iscall(arguments(data)[1]) && (operation(arguments(data)[1]) === /) && isequal(arguments(arguments(data)[1])[1], 1) @@ -161,7 +183,7 @@ function term_matcher_constructor(term, acSets) if frankestein !==nothing result = loop(frankestein, bindings, matchers) - result !== nothing && return success(result, 1) + check_conditions(result) && return success(result, 1) end return nothing @@ -170,6 +192,7 @@ function term_matcher_constructor(term, acSets) # if we want to do commutative checks, i.e. call matcher with different order of the arguments elseif acSets!==nothing && operation(term) in [+, *] function commutative_term_matcher(success, data, bindings) + # println("in +* matcher of $term with data $data") !islist(data) && return nothing # if data is not a list, return nothing !iscall(car(data)) && return nothing # if first element is not a call, return nothing operation(term) !== operation(car(data)) && return nothing # if the operation of data is not the correct one, don't even try @@ -183,64 +206,24 @@ function term_matcher_constructor(term, acSets) candidate = Term{T}(f, @views data_args[inds]) result = loop(candidate, bindings, matchers) - result !== nothing && return success(result,1) + check_conditions(result) && return success(result, 1) end # if car(data) does not subtype to number, it might not be commutative else # call the normal matcher result = loop(car(data), bindings, matchers) - result !== nothing && return success(result, 1) + check_conditions(result) && return success(result, 1) end return nothing end return commutative_term_matcher - # if the operation is sqrt, we have to match also ^(1//2) - elseif operation(term)==sqrt - function sqrt_matcher(success, data, bindings) - !islist(data) && return nothing # if data is not a list, return nothing - data = car(data) - !iscall(data) && return nothing # if first element is not a call, return nothing - - # do the normal matcher - result = loop(data, bindings, matchers) - result !== nothing && return success(result, 1) - - if (operation(data) === ^) && (arguments(data)[2] === 1//2) - T = symtype(arguments(data)[1]) - frankestein = Term{T}(sqrt,[arguments(data)[1]]) - result = loop(frankestein, bindings, matchers) - result !== nothing && return success(result, 1) - end - return nothing - end - return sqrt_matcher - # if the operation is exp, we have to match also ℯ^ - elseif operation(term)==exp - function exp_matcher(success, data, bindings) - !islist(data) && return nothing # if data is not a list, return nothing - data = car(data) - !iscall(data) && return nothing # if first element is not a call, return nothing - - # do the normal matcher - result = loop(data, bindings, matchers) - result !== nothing && return success(result, 1) - - if (operation(data) === ^) && (arguments(data)[1] === ℯ) - T = symtype(arguments(data)[2]) - frankestein = Term{T}(exp,[arguments(data)[2]]) - result = loop(frankestein, bindings, matchers) - result !== nothing && return success(result, 1) - end - return nothing - end - return exp_matcher else function term_matcher(success, data, bindings) !islist(data) && return nothing # if data is not a list, return nothing !iscall(car(data)) && return nothing # if first element is not a call, return nothing result = loop(car(data), bindings, matchers) - result !== nothing && return success(result, 1) + check_conditions(result) && return success(result, 1) return nothing end return term_matcher @@ -250,33 +233,47 @@ end # creates a matcher for a term containing a defslot, such as: # (~x + ...complicated pattern...) * ~!y # normal part (can bee a tree) operation defslot part -function defslot_term_matcher_constructor(term, acSets) +function defslot_term_matcher_constructor(term, acSets, condition) a = arguments(term) defslot_index = findfirst(x -> isa(x, DefSlot), a) # find the defslot in the term defslot = a[defslot_index] if length(a) == 2 - other_part_matcher = matcher(a[defslot_index == 1 ? 2 : 1], acSets) + other_part_matcher = matcher(a[defslot_index == 1 ? 2 : 1], acSets, condition) else others = [a[i] for i in eachindex(a) if i != defslot_index] T = symtype(term) f = operation(term) - other_part_matcher = term_matcher_constructor(Term{T}(f, others), acSets) + other_part_matcher = term_matcher_constructor(Term{T}(f, others), acSets, condition) end - normal_matcher = term_matcher_constructor(term, acSets) + normal_matcher = term_matcher_constructor(term, acSets, condition) + + function defslot_term_matcher(success, data, bindings) + # println("in defslotmatcher of $term with data $data") !islist(data) && return nothing # if data is not a list, return nothing # call the normal matcher, with success function foo1 that simply returns the bindings # <--foo1--> result = normal_matcher((b,n) -> b, data, bindings) result !== nothing && return success(result, 1) + # println("no match, trying defslot") # if no match, try to match with a defslot. # checks whether it matches the normal part if yes executes foo2 # foo2: adds the pair (default value name, default value) to the found bindings # <-------------------foo2----------------------------> result = other_part_matcher((b,n) -> assoc(b, defslot.name, defslot.defaultValue), data, bindings) - result !== nothing && return success(result, 1) - nothing + result === nothing && return nothing + # println("defslot match!") + try + tmp = condition(result) + # tmp==nothing means no conditions are present + if tmp===nothing || tmp + return success(result, 1) + end + catch e + # println("condition failed, continuing") + return success(result, 1) + end end -end +end \ No newline at end of file diff --git a/src/rule.jl b/src/rule.jl index 2c2689ea8..d02dcd8d4 100644 --- a/src/rule.jl +++ b/src/rule.jl @@ -22,7 +22,7 @@ makeslot(s::Symbol, keys) = (push!(keys, s); Slot(s)) # for when the slot is an expression, like `~x::predicate` function makeslot(s::Expr, keys) if !(s.head == :(::)) - error("Syntax for specifying a slot is ~x::predicate, where predicate is a boolean function") + error("Syntax for specifying a slot is ~x::\$predicate, where predicate is a boolean function") end name = s.args[1] @@ -206,6 +206,7 @@ function (r::Rule)(term) rhs = r.rhs try + # TODO is assoc(bindings, :MATCH, term) necessary? # n == 1 means that exactly one term of the input (term,) was matched success(bindings, n) = n == 1 ? (@timer "RHS" rhs(assoc(bindings, :MATCH, term))) : nothing return r.matcher(success, (term,), EMPTY_IMMUTABLE_DICT) @@ -224,11 +225,11 @@ function rewrite_rhs(expr::Expr) if expr.head == :where rhs = expr.args[1] predicate = expr.args[2] - expr = :($predicate ? $rhs : nothing) + return rhs, predicate end - return expr + return expr, nothing end -rewrite_rhs(expr) = expr +rewrite_rhs(expr) = expr, nothing """ @rule LHS => RHS @@ -369,7 +370,8 @@ of an expression. macro rule(expr) @assert expr.head == :call && expr.args[1] == :(=>) lhs = expr.args[2] - rhs = rewrite_rhs(expr.args[3]) + rhs, condition = rewrite_rhs(expr.args[3]) + keys = Symbol[] lhs_term = makepattern(lhs, keys) unique!(keys) @@ -379,7 +381,7 @@ macro rule(expr) Rule( $(QuoteNode(expr)), lhs_pattern, - matcher(lhs_pattern, permutations), + matcher(lhs_pattern, permutations, __MATCHES__ -> $(makeconsequent(condition))), __MATCHES__ -> $(makeconsequent(rhs)), rule_depth($lhs_term) ) @@ -444,7 +446,7 @@ getdepth(r::ACRule) = getdepth(r.rule) macro acrule(expr) @assert expr.head == :call && expr.args[1] == :(=>) lhs = expr.args[2] - rhs = rewrite_rhs(expr.args[3]) + rhs, condition = rewrite_rhs(expr.args[3]) keys = Symbol[] lhs_term = makepattern(lhs, keys) unique!(keys) @@ -456,7 +458,7 @@ macro acrule(expr) lhs_pattern = $(lhs_term) rule = Rule($(QuoteNode(expr)), lhs_pattern, - matcher(lhs_pattern, permutations), + matcher(lhs_pattern, permutations, __MATCHES__ -> $(makeconsequent(condition))), __MATCHES__ -> $(makeconsequent(rhs)), rule_depth($lhs_term)) ACRule(permutations, rule, $arity) @@ -466,7 +468,7 @@ end macro ordered_acrule(expr) @assert expr.head == :call && expr.args[1] == :(=>) lhs = expr.args[2] - rhs = rewrite_rhs(expr.args[3]) + rhs, condition = rewrite_rhs(expr.args[3]) keys = Symbol[] lhs_term = makepattern(lhs, keys) unique!(keys) @@ -478,7 +480,7 @@ macro ordered_acrule(expr) lhs_pattern = $(lhs_term) rule = Rule($(QuoteNode(expr)), lhs_pattern, - matcher(lhs_pattern, combinations), + matcher(lhs_pattern, combinations, __MATCHES__ -> $(makeconsequent(condition))), __MATCHES__ -> $(makeconsequent(rhs)), rule_depth($lhs_term)) ACRule(combinations, rule, $arity) @@ -577,4 +579,4 @@ macro timerewrite(expr) :(timerewrite(()->$(esc(expr)))) end -Base.@deprecate RuleSet(x) Postwalk(Chain(x)) +Base.@deprecate RuleSet(x) Postwalk(Chain(x)) \ No newline at end of file diff --git a/test/rewrite.jl b/test/rewrite.jl index 0d228af8f..6288928e6 100644 --- a/test/rewrite.jl +++ b/test/rewrite.jl @@ -143,14 +143,17 @@ end @test r1(sqrt(a)) === (a, 1//2) # uses sqrt_matcher end -@testset "Alternate form of special functions" begin - rsqrt = @rule sqrt(~x) => ~x - @test rsqrt(sqrt(x))===x - @test rsqrt((x)^(1//2))===x - - rexp = @rule exp(~x) => ~x - @test rexp(exp(x)) === x - @test rexp(ℯ^x) === x +@testset "conditions inside rule" begin + r = @rule (~x)^(~m)*(~y)^(~n) => (~x, ~m, ~y, ~n) where (~m)^(~n)==8 + @test r((a^2)*(b^3)) === (a, 2, b, 3) + @test r((b^2)*(a^3)) === (b, 2, a, 3) + + r_defslot = @rule (~x)^(~m)*(~y)^(~!n) => (~x, ~m, ~y, ~n) where (~m)^(~n)==8 + @test r_defslot(y*x^8) === (x, 8, y, 1) + @test r_defslot(x*y^8) === (y, 8, x, 1) + + r_defslot_2 = @rule (~x)^(~!m) => (~x, ~m) where false + @test C2(y)===nothing end using SymbolicUtils: @capture