Skip to content

Commit 1feb709

Browse files
Merge pull request #779 from Bumblebee00/speed_up_rules
sped up matching process in the commutative_term_matcher
2 parents fa6ed8d + a27d289 commit 1feb709

File tree

3 files changed

+15
-9
lines changed

3 files changed

+15
-9
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ AbstractTrees = "0.4"
4141
ArrayInterface = "7.8"
4242
Bijections = "0.1.2, 0.2"
4343
ChainRulesCore = "1"
44-
Combinatorics = "1.0"
44+
Combinatorics = "1 - 1.0.2"
4545
ConstructionBase = "1.5.7"
4646
DataStructures = "0.18, 0.19"
4747
DocStringExtensions = "0.8, 0.9"

src/matchers.jl

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -169,26 +169,31 @@ function term_matcher_constructor(term, acSets)
169169
return pow_term_matcher
170170
# if we want to do commutative checks, i.e. call matcher with different order of the arguments
171171
elseif acSets!==nothing && operation(term) in [+, *]
172+
has_segment = any([isa(a,Segment) for a in arguments(term)])
172173
function commutative_term_matcher(success, data, bindings)
173174
!islist(data) && return nothing # if data is not a list, return nothing
174-
!iscall(car(data)) && return nothing # if first element is not a call, return nothing
175-
operation(term) !== operation(car(data)) && return nothing # if the operation of data is not the correct one, don't even try
175+
data = car(data)
176+
!iscall(data) && return nothing # if first element is not a call, return nothing
177+
operation(term) !== operation(data) && return nothing # if the operation of data is not the correct one, don't even try
178+
data_args = arguments(data)
179+
# if the number of arguments is different, and the rule doesnt have a segment, return nothing
180+
!has_segment && length(matchers)-1 !== length(data_args) && return nothing
181+
176182

177-
T = symtype(car(data))
178-
if T <: Number
179-
f = operation(car(data))
180-
data_args = arguments(car(data))
183+
T = symtype(data)
184+
if T <: Number && length(data_args)<COMM_CHECKS_LIMIT[]
185+
f = operation(data)
181186

182187
for inds in acSets(eachindex(data_args), length(data_args))
183188
candidate = Term{T}(f, @views data_args[inds])
184189

185190
result = loop(candidate, bindings, matchers)
186191
result !== nothing && return success(result,1)
187192
end
188-
# if car(data) does not subtype to number, it might not be commutative
193+
# if data does not subtype to number, it might not be commutative
189194
else
190195
# call the normal matcher
191-
result = loop(car(data), bindings, matchers)
196+
result = loop(data, bindings, matchers)
192197
result !== nothing && return success(result, 1)
193198
end
194199
return nothing

src/rule.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11

22
@inline alwaystrue(x) = true
3+
const COMM_CHECKS_LIMIT = Ref(10)
34

45
# Matcher patterns with Slot, DefSlot and Segment
56

0 commit comments

Comments
 (0)