Skip to content

Commit 1e70235

Browse files
committed
refactor: reduce some redundant allocations
1 parent 5bfb9d7 commit 1e70235

File tree

1 file changed

+135
-46
lines changed

1 file changed

+135
-46
lines changed

src/ArrayNode.jl

Lines changed: 135 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -47,28 +47,34 @@ mutable struct ArrayTree{T,D,S<:StructVector{NodeData{T,D}}}
4747
free_count::UInt16
4848

4949
function ArrayTree{T,D}(n::Int; array_type::Type{<:AbstractVector}=Vector) where {T,D}
50-
# Create backing arrays of the specified type
51-
degree = array_type{UInt8}(undef, n)
52-
constant = array_type{Bool}(undef, n)
53-
val = array_type{T}(undef, n)
54-
feature = array_type{UInt16}(undef, n)
55-
op = array_type{UInt8}(undef, n)
56-
children = array_type{NTuple{D,UInt16}}(undef, n)
57-
58-
# Create a StructVector from the backing arrays
59-
nodes = StructVector{NodeData{T,D}}((
60-
degree=degree,
61-
constant=constant,
62-
val=val,
63-
feature=feature,
64-
op=op,
65-
children=children,
66-
))
50+
# Create uninitialized StructVector directly
51+
# For custom array types, we'd need to pass them to StructVector somehow
52+
# For now, just use the default
53+
nodes = if array_type === Vector
54+
StructVector{NodeData{T,D}}(undef, n)
55+
else
56+
# For other array types, create backing arrays manually
57+
degree = array_type{UInt8}(undef, n)
58+
constant = array_type{Bool}(undef, n)
59+
val = array_type{T}(undef, n)
60+
feature = array_type{UInt16}(undef, n)
61+
op = array_type{UInt8}(undef, n)
62+
children = array_type{NTuple{D,UInt16}}(undef, n)
63+
StructVector{NodeData{T,D}}((
64+
degree=degree,
65+
constant=constant,
66+
val=val,
67+
feature=feature,
68+
op=op,
69+
children=children,
70+
))
71+
end
6772

6873
S = typeof(nodes)
69-
tree = new{T,D,S}(nodes, UInt16(0), UInt16(0), Vector{UInt16}(undef, n), UInt16(n))
70-
# Initialize free list
71-
for i in 1:n
74+
free_list = Vector{UInt16}(undef, n)
75+
tree = new{T,D,S}(nodes, UInt16(0), UInt16(0), free_list, UInt16(n))
76+
# Initialize free list in-place
77+
@inbounds @simd for i in 1:n
7278
tree.free_list[i] = UInt16(i)
7379
end
7480
return tree
@@ -270,7 +276,9 @@ function ArrayNode{T,D}(
270276
return ArrayNode{T,D,typeof(tree.nodes)}(tree, idx)
271277
end
272278

273-
function copy_subtree!(dst::ArrayTree{T,D}, src::ArrayTree{T,D}, src_idx::UInt16) where {T,D}
279+
function copy_subtree!(
280+
dst::ArrayTree{T,D}, src::ArrayTree{T,D}, src_idx::UInt16
281+
) where {T,D}
274282
dst_idx = allocate_node!(dst)
275283

276284
@inbounds begin
@@ -314,17 +322,14 @@ end
314322
function unsafe_get_children(n::ArrayNode{T,D,S}) where {T,D,S}
315323
tree = getfield(n, :tree)
316324
idx = getfield(n, :idx)
317-
return ntuple(
318-
i -> begin
319-
child_idx = @inbounds tree.nodes.children[idx][i]
320-
if child_idx == 0
321-
Nullable(true, n)
322-
else
323-
Nullable(false, ArrayNode{T,D,S}(tree, child_idx))
324-
end
325-
end,
326-
Val(D),
327-
)
325+
return ntuple(i -> begin
326+
child_idx = @inbounds tree.nodes.children[idx][i]
327+
if child_idx == 0
328+
Nullable(true, n)
329+
else
330+
Nullable(false, ArrayNode{T,D,S}(tree, child_idx))
331+
end
332+
end, Val(D))
328333
end
329334

330335
function get_children(n::ArrayNode{T,D,S}, ::Val{d}) where {T,D,S,d}
@@ -379,31 +384,115 @@ function set_children!(n::ArrayNode{T,D,S}, cs::Tuple) where {T,D,S}
379384
return nothing
380385
end
381386

387+
# Helper to mark nodes as reachable from a given root
388+
function mark_reachable!(
389+
reachable::Vector{Bool}, tree::ArrayTree{T,D}, idx::UInt16
390+
) where {T,D}
391+
if idx == 0 || reachable[idx]
392+
return nothing
393+
end
394+
reachable[idx] = true
395+
degree = @inbounds tree.nodes.degree[idx]
396+
for i in 1:degree
397+
child_idx = @inbounds tree.nodes.children[idx][i]
398+
if child_idx != 0
399+
mark_reachable!(reachable, tree, child_idx)
400+
end
401+
end
402+
end
403+
382404
# Copy
405+
# Note: break_sharing parameter is ignored since ArrayNode doesn't preserve sharing
383406
function copy_node(n::ArrayNode{T,D,S}; break_sharing::Val{BS}=Val(false)) where {T,D,S,BS}
407+
# BS parameter unused - ArrayNode always breaks sharing since each node owns its tree
384408
tree = getfield(n, :tree)
385409
idx = getfield(n, :idx)
410+
n_capacity = length(tree.nodes)
386411

387-
# Count nodes to determine tree size needed
388-
node_count = count_nodes(n)
389-
# Add some buffer space
390-
tree_size = max(32, node_count * 2)
391-
392-
# Determine the array type from the existing tree's nodes
393-
# Default to Vector since that's the most common case
394-
# For other array types, we'd need more sophisticated type extraction
412+
# Create new tree with same capacity
395413
new_tree = if tree.nodes.degree isa Vector
396-
ArrayTree{T,D}(tree_size; array_type=Vector)
414+
ArrayTree{T,D}(n_capacity; array_type=Vector)
397415
else
398-
# For other array types like FixedSizeVector, we just use default
399-
ArrayTree{T,D}(tree_size)
416+
ArrayTree{T,D}(n_capacity)
417+
end
418+
419+
# Direct array copy - works for both full tree and subtree
420+
new_tree.nodes.degree[:] = tree.nodes.degree
421+
new_tree.nodes.constant[:] = tree.nodes.constant
422+
new_tree.nodes.val[:] = tree.nodes.val
423+
new_tree.nodes.feature[:] = tree.nodes.feature
424+
new_tree.nodes.op[:] = tree.nodes.op
425+
new_tree.nodes.children[:] = tree.nodes.children
426+
427+
# Set the root to our copied node
428+
new_tree.root_idx = idx
429+
430+
if idx == tree.root_idx
431+
# Full tree copy - just copy all metadata
432+
new_tree.n_nodes = tree.n_nodes
433+
new_tree.free_count = tree.free_count
434+
new_tree.free_list[:] = tree.free_list
435+
else
436+
# Subtree copy - need to update free list to exclude unreachable nodes
437+
reachable = fill(false, n_capacity)
438+
mark_reachable!(reachable, new_tree, idx)
439+
440+
# Reset free list with unreachable nodes
441+
new_tree.free_count = 0
442+
new_tree.n_nodes = 0
443+
for i in 1:n_capacity
444+
if !reachable[i]
445+
new_tree.free_count += 1
446+
new_tree.free_list[new_tree.free_count] = UInt16(i)
447+
else
448+
new_tree.n_nodes += 1
449+
end
450+
end
400451
end
401-
new_idx = copy_subtree!(new_tree, tree, idx)
402-
new_tree.root_idx = new_idx
403452

404-
return ArrayNode{T,D,S}(new_tree, new_idx)
453+
return ArrayNode{T,D,S}(new_tree, new_tree.root_idx)
405454
end
406455

407456
Base.copy(n::ArrayNode) = copy_node(n)
408457

458+
# tree_mapreduce implementation
459+
function tree_mapreduce(
460+
f::F,
461+
op::G,
462+
tree::ArrayNode{T,D,S},
463+
result_type::Type{RT}=Undefined;
464+
f_on_shared::H=(result, is_shared) -> result,
465+
break_sharing::Val{BS}=Val(false),
466+
) where {F<:Function,G<:Function,H<:Function,T,D,S,RT,BS}
467+
return tree_mapreduce(f, f, op, tree, result_type; f_on_shared, break_sharing)
468+
end
469+
470+
function tree_mapreduce(
471+
f_leaf::F1,
472+
f_branch::F2,
473+
op::G,
474+
tree::ArrayNode{T,D,S},
475+
result_type::Type{RT}=Undefined;
476+
f_on_shared::H=(result, is_shared) -> result,
477+
break_sharing::Val{BS}=Val(false),
478+
) where {F1<:Function,F2<:Function,G<:Function,H<:Function,T,D,S,RT,BS}
479+
# ArrayNode doesn't preserve sharing, so we can use simple recursion
480+
if tree.degree == 0
481+
return f_leaf(tree)
482+
else
483+
# Apply to children
484+
degree = tree.degree
485+
children_results = ntuple(Val(Int(degree))) do i
486+
child = get_children(tree, Val(degree))[i]
487+
tree_mapreduce(
488+
f_leaf, f_branch, op, child, result_type; f_on_shared, break_sharing
489+
)
490+
end
491+
492+
# Reduce children results
493+
self_result = f_branch(tree)
494+
return op(self_result, children_results...)
495+
end
496+
end
497+
409498
end # module

0 commit comments

Comments
 (0)