Skip to content

Commit 74c8dc1

Browse files
committed
refactor: fix some performance regressions in upstream SR.jl
1 parent 3a42bce commit 74c8dc1

File tree

2 files changed

+64
-55
lines changed

2 files changed

+64
-55
lines changed

src/Expression.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -274,23 +274,23 @@ copy_node(ex::AbstractExpression; kws...) = copy(ex)
274274
count_nodes(ex::AbstractExpression; kws...) = count_nodes(get_tree(ex); kws...)
275275

276276
function tree_mapreduce(
277-
f::Function,
278-
op::Function,
277+
f::F,
278+
op::G,
279279
ex::AbstractExpression,
280-
result_type::Type=Undefined;
280+
result_type::Type{RT}=Undefined;
281281
kws...,
282-
)
283-
return tree_mapreduce(f, op, get_tree(ex), result_type; kws...)
282+
) where {F<:Function,G<:Function,RT}
283+
return tree_mapreduce(f, op, get_tree(ex), RT; kws...)
284284
end
285285
function tree_mapreduce(
286-
f_leaf::Function,
287-
f_branch::Function,
288-
op::Function,
286+
f_leaf::F,
287+
f_branch::G,
288+
op::H,
289289
ex::AbstractExpression,
290-
result_type::Type=Undefined;
290+
result_type::Type{RT}=Undefined;
291291
kws...,
292-
)
293-
return tree_mapreduce(f_leaf, f_branch, op, get_tree(ex), result_type; kws...)
292+
) where {F<:Function,G<:Function,H<:Function,RT}
293+
return tree_mapreduce(f_leaf, f_branch, op, get_tree(ex), RT; kws...)
294294
end
295295

296296
count_constant_nodes(ex::AbstractExpression) = count_constant_nodes(get_tree(ex))

src/base.jl

Lines changed: 53 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,9 @@ function tree_mapreduce(
8181
tree::AbstractNode,
8282
result_type::Type{RT}=Undefined;
8383
f_on_shared::H=(result, is_shared) -> result,
84-
break_sharing=Val(false),
85-
) where {RT,F<:Function,G<:Function,H<:Function}
86-
return tree_mapreduce(f, f, op, tree, RT; f_on_shared, break_sharing)
84+
break_sharing::Val{BS}=Val(false),
85+
) where {RT,F<:Function,G<:Function,H<:Function,BS}
86+
return tree_mapreduce(f, f, op, tree, RT; f_on_shared, break_sharing=Val(BS))
8787
end
8888
function tree_mapreduce(
8989
f_leaf::F1,
@@ -92,8 +92,8 @@ function tree_mapreduce(
9292
tree::AbstractNode,
9393
result_type::Type{RT}=Undefined;
9494
f_on_shared::H=(result, is_shared) -> result,
95-
break_sharing::Val=Val(false),
96-
) where {F1<:Function,F2<:Function,G<:Function,H<:Function,RT}
95+
break_sharing::Val{BS}=Val(false),
96+
) where {F1<:Function,F2<:Function,G<:Function,H<:Function,RT,BS}
9797

9898
# Trick taken from here:
9999
# https://discourse.julialang.org/t/recursive-inner-functions-a-thousand-times-slower/85604/5
@@ -108,7 +108,7 @@ function tree_mapreduce(
108108
end
109109
end
110110

111-
sharing = preserve_sharing(typeof(tree)) && break_sharing === Val(false)
111+
sharing = preserve_sharing(typeof(tree)) && !BS
112112

113113
RT == Undefined &&
114114
sharing &&
@@ -222,14 +222,14 @@ end
222222
223223
Count the number of nodes in the tree.
224224
"""
225-
function count_nodes(tree::AbstractNode; break_sharing=Val(false))
225+
function count_nodes(tree::AbstractNode; break_sharing::Val{BS}=Val(false)) where {BS}
226226
return tree_mapreduce(
227227
_ -> 1,
228228
+,
229229
tree,
230230
Int64;
231231
f_on_shared=(c, is_shared) -> is_shared ? 0 : c,
232-
break_sharing,
232+
break_sharing=Val(BS),
233233
)
234234
end
235235

@@ -239,10 +239,14 @@ end
239239
Apply a function to each node in a tree without returning the results.
240240
"""
241241
function foreach(
242-
f::F, tree::AbstractNode; break_sharing::Val=Val(false)
243-
) where {F<:Function}
242+
f::F, tree::AbstractNode; break_sharing::Val{BS}=Val(false)
243+
) where {F<:Function,BS}
244244
tree_mapreduce(
245-
t -> (@inline(f(t)); nothing), Returns(nothing), tree, Nothing; break_sharing
245+
t -> (@inline(f(t)); nothing),
246+
Returns(nothing),
247+
tree,
248+
Nothing;
249+
break_sharing=Val(BS),
246250
)
247251
return nothing
248252
end
@@ -260,10 +264,10 @@ function filter_map(
260264
map_fnc::G,
261265
tree::AbstractNode,
262266
result_type::Type{GT};
263-
break_sharing::Val=Val(false),
264-
) where {F<:Function,G<:Function,GT}
265-
stack = Array{GT}(undef, count(filter_fnc, tree; init=0, break_sharing))
266-
filter_map!(filter_fnc, map_fnc, stack, tree; break_sharing)
267+
break_sharing::Val{BS}=Val(false),
268+
) where {F<:Function,G<:Function,GT,BS}
269+
stack = Array{GT}(undef, count(filter_fnc, tree; init=0, break_sharing=Val(BS)))
270+
filter_map!(filter_fnc, map_fnc, stack, tree; break_sharing=Val(BS))
267271
return stack::Vector{GT}
268272
end
269273

@@ -277,10 +281,10 @@ function filter_map!(
277281
map_fnc::G,
278282
destination::Vector{GT},
279283
tree::AbstractNode;
280-
break_sharing::Val=Val(false),
281-
) where {GT,F<:Function,G<:Function}
284+
break_sharing::Val{BS}=Val(false),
285+
) where {GT,F<:Function,G<:Function,BS}
282286
pointer = Ref(0)
283-
foreach(tree; break_sharing) do t
287+
foreach(tree; break_sharing=Val(BS)) do t
284288
if @inline(filter_fnc(t))
285289
map_result = @inline(map_fnc(t))::GT
286290
@inbounds destination[pointer.x += 1] = map_result
@@ -294,55 +298,60 @@ end
294298
295299
Filter nodes of a tree, returning a flat array of the nodes for which the function returns `true`.
296300
"""
297-
function filter(f::F, tree::AbstractNode; break_sharing::Val=Val(false)) where {F<:Function}
298-
return filter_map(f, identity, tree, typeof(tree); break_sharing)
301+
function filter(
302+
f::F, tree::AbstractNode; break_sharing::Val{BS}=Val(false)
303+
) where {F<:Function,BS}
304+
return filter_map(f, identity, tree, typeof(tree); break_sharing=Val(BS))
299305
end
300306

301307
"""
302308
collect(tree::AbstractNode; break_sharing::Val=Val(false))
303309
304310
Collect all nodes in a tree into a flat array in depth-first order.
305311
"""
306-
function collect(tree::AbstractNode; break_sharing::Val=Val(false))
307-
return filter(Returns(true), tree; break_sharing)
312+
function collect(tree::AbstractNode; break_sharing::Val{BS}=Val(false)) where {BS}
313+
return filter(Returns(true), tree; break_sharing=Val(BS))
308314
end
309315

310316
"""
311-
map(f::F, tree::AbstractNode, result_type::Type{RT}=Nothing; break_sharing::Val=Val(false)) where {F<:Function,RT}
317+
map(f::F, tree::AbstractNode, result_type::Type{RT}=Nothing; break_sharing::Val{BS}=Val(false)) where {F<:Function,RT,BS}
312318
313319
Map a function over a tree and return a flat array of the results in depth-first order.
314320
Pre-specifying the `result_type` of the function can be used to avoid extra allocations.
315321
"""
316322
function map(
317-
f::F, tree::AbstractNode, result_type::Type{RT}=Nothing; break_sharing::Val=Val(false)
318-
) where {F<:Function,RT}
323+
f::F,
324+
tree::AbstractNode,
325+
result_type::Type{RT}=Nothing;
326+
break_sharing::Val{BS}=Val(false),
327+
) where {F<:Function,RT,BS}
319328
if RT == Nothing
320-
return map(f, collect(tree; break_sharing))
329+
return map(f, collect(tree; break_sharing=Val(BS)))
321330
else
322-
return filter_map(Returns(true), f, tree, result_type; break_sharing)
331+
return filter_map(Returns(true), f, tree, result_type; break_sharing=Val(BS))
323332
end
324333
end
325334

326335
"""
327-
count(f::F, tree::AbstractNode; init=0, break_sharing::Val=Val(false)) where {F<:Function}
336+
count(f::F, tree::AbstractNode; init=0, break_sharing::Val{BS}=Val(false)) where {F<:Function,BS}
328337
329338
Count the number of nodes in a tree for which the function returns `true`.
330339
"""
331340
function count(
332-
f::F, tree::AbstractNode; init=0, break_sharing::Val=Val(false)
333-
) where {F<:Function}
341+
f::F, tree::AbstractNode; init=0, break_sharing::Val{BS}=Val(false)
342+
) where {F<:Function,BS}
334343
return tree_mapreduce(
335344
t -> @inline(f(t)) ? 1 : 0,
336345
+,
337346
tree,
338347
Int64;
339348
f_on_shared=(c, is_shared) -> is_shared ? 0 : c,
340-
break_sharing,
349+
break_sharing=Val(BS),
341350
) + init
342351
end
343352

344353
"""
345-
sum(f::Function, tree::AbstractNode; result_type=Undefined, f_on_shared=_default_shared_aggregation, break_sharing::Val=Val(false)) where {F<:Function}
354+
sum(f::Function, tree::AbstractNode; result_type=Undefined, f_on_shared=_default_shared_aggregation, break_sharing::Val{BS}=Val(false)) where {F<:Function,BS}
346355
347356
Sum the results of a function over a tree. For graphs with shared nodes
348357
such as [`GraphNode`](@ref), the function `f_on_shared` is called on the result
@@ -386,7 +395,7 @@ function mapreduce(
386395
"Must specify `result_type` as a keyword argument to `mapreduce` if `preserve_sharing` is true."
387396
)
388397
end
389-
return tree_mapreduce(f, op, tree, RT; f_on_shared, break_sharing)
398+
return tree_mapreduce(f, op, tree, RT; f_on_shared, break_sharing=Val(BS))
390399
end
391400

392401
isempty(::AbstractNode) = false
@@ -396,8 +405,8 @@ end
396405
@unstable iterate(::AbstractNode, stack) =
397406
isempty(stack) ? nothing : (popfirst!(stack), stack)
398407
in(item, tree::AbstractNode) = any(t -> t == item, tree)
399-
function length(tree::AbstractNode; break_sharing::Val=Val(false))
400-
return count_nodes(tree; break_sharing)
408+
function length(tree::AbstractNode; break_sharing::Val{BS}=Val(false)) where {BS}
409+
return count_nodes(tree; break_sharing=Val(BS))
401410
end
402411

403412
"""
@@ -407,8 +416,8 @@ Compute a hash of a tree. This will compute a hash differently
407416
if nodes are shared in a tree. This is ignored if `break_sharing` is set to `Val(true)`.
408417
"""
409418
function hash(
410-
tree::AbstractExpressionNode{T}, h::UInt=zero(UInt); break_sharing::Val=Val(false)
411-
) where {T}
419+
tree::AbstractExpressionNode{T}, h::UInt=zero(UInt); break_sharing::Val{BS}=Val(false)
420+
) where {T,BS}
412421
return tree_mapreduce(
413422
t -> leaf_hash(h, t),
414423
identity,
@@ -417,7 +426,7 @@ function hash(
417426
UInt;
418427
f_on_shared=(cur_hash, is_shared) ->
419428
is_shared ? hash((:shared, cur_hash), h) : cur_hash,
420-
break_sharing,
429+
break_sharing=Val(BS),
421430
)
422431
end
423432
function leaf_hash(h::UInt, t::AbstractExpressionNode)
@@ -428,17 +437,17 @@ function branch_hash(h::UInt, t::AbstractExpressionNode, children::Vararg{Any,M}
428437
end
429438

430439
"""
431-
copy_node(tree::AbstractExpressionNode; break_sharing::Val=Val(false))
440+
copy_node(tree::AbstractExpressionNode; break_sharing::Val{BS}=Val(false)) where {BS}
432441
433442
Copy a node, recursively copying all children nodes.
434443
This is more efficient than the built-in copy.
435444
436445
If `break_sharing` is set to `Val(true)`, sharing in a tree will be ignored.
437446
"""
438447
function copy_node(
439-
tree::N; break_sharing::Val=Val(false)
440-
) where {T,N<:AbstractExpressionNode{T}}
441-
return tree_mapreduce(leaf_copy, identity, branch_copy, tree, N; break_sharing)
448+
tree::N; break_sharing::Val{BS}=Val(false)
449+
) where {T,N<:AbstractExpressionNode{T},BS}
450+
return tree_mapreduce(leaf_copy, identity, branch_copy, tree, N; break_sharing=Val(BS))
442451
end
443452
function leaf_copy(t::N) where {T,N<:AbstractExpressionNode{T}}
444453
if t.constant
@@ -459,8 +468,8 @@ This is more efficient than the built-in copy.
459468
460469
If `break_sharing` is set to `Val(true)`, sharing in a tree will be ignored.
461470
"""
462-
function copy(tree::AbstractExpressionNode; break_sharing::Val=Val(false))
463-
return copy_node(tree; break_sharing)
471+
function copy(tree::AbstractExpressionNode; break_sharing::Val{BS}=Val(false)) where {BS}
472+
return copy_node(tree; break_sharing=Val(BS))
464473
end
465474

466475
"""

0 commit comments

Comments
 (0)