Skip to content

Commit cb725ef

Browse files
committed
now ...^(1//2) matches in the rule with sqrt, and ℯ^... matches in the rule with exp
1 parent da99287 commit cb725ef

File tree

2 files changed

+51
-1
lines changed

2 files changed

+51
-1
lines changed

src/matchers.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,46 @@ function term_matcher_constructor(term, acSets)
194194
return nothing
195195
end
196196
return commutative_term_matcher
197+
# if the operation is sqrt, we have to match also ^(1//2)
198+
elseif operation(term)==sqrt
199+
function sqrt_matcher(success, data, bindings)
200+
!islist(data) && return nothing # if data is not a list, return nothing
201+
data = car(data)
202+
!iscall(data) && return nothing # if first element is not a call, return nothing
203+
204+
# do the normal matcher
205+
result = loop(data, bindings, matchers)
206+
result !== nothing && return success(result, 1)
207+
208+
if (operation(data) === ^) && (arguments(data)[2] === 1//2)
209+
T = symtype(arguments(data)[1])
210+
frankestein = Term{T}(sqrt,[arguments(data)[1]])
211+
result = loop(frankestein, bindings, matchers)
212+
result !== nothing && return success(result, 1)
213+
end
214+
return nothing
215+
end
216+
return sqrt_matcher
217+
# if the operation is exp, we have to match also ℯ^
218+
elseif operation(term)==exp
219+
function exp_matcher(success, data, bindings)
220+
!islist(data) && return nothing # if data is not a list, return nothing
221+
data = car(data)
222+
!iscall(data) && return nothing # if first element is not a call, return nothing
223+
224+
# do the normal matcher
225+
result = loop(data, bindings, matchers)
226+
result !== nothing && return success(result, 1)
227+
228+
if (operation(data) === ^) && (arguments(data)[1] === ℯ)
229+
T = symtype(arguments(data)[2])
230+
frankestein = Term{T}(exp,[arguments(data)[2]])
231+
result = loop(frankestein, bindings, matchers)
232+
result !== nothing && return success(result, 1)
233+
end
234+
return nothing
235+
end
236+
return exp_matcher
197237
else
198238
function term_matcher(success, data, bindings)
199239
!islist(data) && return nothing # if data is not a list, return nothing

test/rewrite.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ end
113113
@test r_predicate(x+2) === (x, 2)
114114
@test r_predicate(x+2.0) !== (x, 2.0)
115115
# Note: r_predicate(x+2.0) doesnt return nothing, but (x+2.0, 0)
116-
# becasue of the defslot
116+
# because of the defslot
117117
end
118118

119119
@testset "power matcher with negative exponent" begin
@@ -159,6 +159,16 @@ end
159159
@test r1(sqrt(a)) === (a, 1//2) # uses sqrt_matcher
160160
end
161161

162+
@testset "Alternate form of special functions" begin
163+
rsqrt = @rule sqrt(~x) => ~x
164+
@test rsqrt(sqrt(x))===x
165+
@test rsqrt((x)^(1//2))===x
166+
167+
rexp = @rule exp(~x) => ~x
168+
@test rexp(exp(x)) === x
169+
@test rexp(ℯ^x) === x
170+
end
171+
162172
using SymbolicUtils: @capture
163173

164174
@testset "Capture form" begin

0 commit comments

Comments
 (0)