@@ -47,28 +47,34 @@ mutable struct ArrayTree{T,D,S<:StructVector{NodeData{T,D}}}
4747 free_count:: UInt16
4848
4949 function ArrayTree {T,D} (n:: Int ; array_type:: Type{<:AbstractVector} = Vector) where {T,D}
50- # Create backing arrays of the specified type
51- degree = array_type {UInt8} (undef, n)
52- constant = array_type {Bool} (undef, n)
53- val = array_type {T} (undef, n)
54- feature = array_type {UInt16} (undef, n)
55- op = array_type {UInt8} (undef, n)
56- children = array_type {NTuple{D,UInt16}} (undef, n)
57-
58- # Create a StructVector from the backing arrays
59- nodes = StructVector {NodeData{T,D}} ((
60- degree= degree,
61- constant= constant,
62- val= val,
63- feature= feature,
64- op= op,
65- children= children,
66- ))
50+ # Create uninitialized StructVector directly
51+ # For custom array types, we'd need to pass them to StructVector somehow
52+ # For now, just use the default
53+ nodes = if array_type === Vector
54+ StructVector {NodeData{T,D}} (undef, n)
55+ else
56+ # For other array types, create backing arrays manually
57+ degree = array_type {UInt8} (undef, n)
58+ constant = array_type {Bool} (undef, n)
59+ val = array_type {T} (undef, n)
60+ feature = array_type {UInt16} (undef, n)
61+ op = array_type {UInt8} (undef, n)
62+ children = array_type {NTuple{D,UInt16}} (undef, n)
63+ StructVector {NodeData{T,D}} ((
64+ degree= degree,
65+ constant= constant,
66+ val= val,
67+ feature= feature,
68+ op= op,
69+ children= children,
70+ ))
71+ end
6772
6873 S = typeof (nodes)
69- tree = new {T,D,S} (nodes, UInt16 (0 ), UInt16 (0 ), Vector {UInt16} (undef, n), UInt16 (n))
70- # Initialize free list
71- for i in 1 : n
74+ free_list = Vector {UInt16} (undef, n)
75+ tree = new {T,D,S} (nodes, UInt16 (0 ), UInt16 (0 ), free_list, UInt16 (n))
76+ # Initialize free list in-place
77+ @inbounds @simd for i in 1 : n
7278 tree. free_list[i] = UInt16 (i)
7379 end
7480 return tree
@@ -270,7 +276,9 @@ function ArrayNode{T,D}(
270276 return ArrayNode {T,D,typeof(tree.nodes)} (tree, idx)
271277end
272278
273- function copy_subtree! (dst:: ArrayTree{T,D} , src:: ArrayTree{T,D} , src_idx:: UInt16 ) where {T,D}
279+ function copy_subtree! (
280+ dst:: ArrayTree{T,D} , src:: ArrayTree{T,D} , src_idx:: UInt16
281+ ) where {T,D}
274282 dst_idx = allocate_node! (dst)
275283
276284 @inbounds begin
@@ -314,17 +322,14 @@ end
314322function unsafe_get_children (n:: ArrayNode{T,D,S} ) where {T,D,S}
315323 tree = getfield (n, :tree )
316324 idx = getfield (n, :idx )
317- return ntuple (
318- i -> begin
319- child_idx = @inbounds tree. nodes. children[idx][i]
320- if child_idx == 0
321- Nullable (true , n)
322- else
323- Nullable (false , ArrayNode {T,D,S} (tree, child_idx))
324- end
325- end ,
326- Val (D),
327- )
325+ return ntuple (i -> begin
326+ child_idx = @inbounds tree. nodes. children[idx][i]
327+ if child_idx == 0
328+ Nullable (true , n)
329+ else
330+ Nullable (false , ArrayNode {T,D,S} (tree, child_idx))
331+ end
332+ end , Val (D))
328333end
329334
330335function get_children (n:: ArrayNode{T,D,S} , :: Val{d} ) where {T,D,S,d}
@@ -379,31 +384,115 @@ function set_children!(n::ArrayNode{T,D,S}, cs::Tuple) where {T,D,S}
379384 return nothing
380385end
381386
387+ # Helper to mark nodes as reachable from a given root
388+ function mark_reachable! (
389+ reachable:: Vector{Bool} , tree:: ArrayTree{T,D} , idx:: UInt16
390+ ) where {T,D}
391+ if idx == 0 || reachable[idx]
392+ return nothing
393+ end
394+ reachable[idx] = true
395+ degree = @inbounds tree. nodes. degree[idx]
396+ for i in 1 : degree
397+ child_idx = @inbounds tree. nodes. children[idx][i]
398+ if child_idx != 0
399+ mark_reachable! (reachable, tree, child_idx)
400+ end
401+ end
402+ end
403+
382404# Copy
405+ # Note: break_sharing parameter is ignored since ArrayNode doesn't preserve sharing
383406function copy_node (n:: ArrayNode{T,D,S} ; break_sharing:: Val{BS} = Val (false )) where {T,D,S,BS}
407+ # BS parameter unused - ArrayNode always breaks sharing since each node owns its tree
384408 tree = getfield (n, :tree )
385409 idx = getfield (n, :idx )
410+ n_capacity = length (tree. nodes)
386411
387- # Count nodes to determine tree size needed
388- node_count = count_nodes (n)
389- # Add some buffer space
390- tree_size = max (32 , node_count * 2 )
391-
392- # Determine the array type from the existing tree's nodes
393- # Default to Vector since that's the most common case
394- # For other array types, we'd need more sophisticated type extraction
412+ # Create new tree with same capacity
395413 new_tree = if tree. nodes. degree isa Vector
396- ArrayTree {T,D} (tree_size ; array_type= Vector)
414+ ArrayTree {T,D} (n_capacity ; array_type= Vector)
397415 else
398- # For other array types like FixedSizeVector, we just use default
399- ArrayTree {T,D} (tree_size)
416+ ArrayTree {T,D} (n_capacity)
417+ end
418+
419+ # Direct array copy - works for both full tree and subtree
420+ new_tree. nodes. degree[:] = tree. nodes. degree
421+ new_tree. nodes. constant[:] = tree. nodes. constant
422+ new_tree. nodes. val[:] = tree. nodes. val
423+ new_tree. nodes. feature[:] = tree. nodes. feature
424+ new_tree. nodes. op[:] = tree. nodes. op
425+ new_tree. nodes. children[:] = tree. nodes. children
426+
427+ # Set the root to our copied node
428+ new_tree. root_idx = idx
429+
430+ if idx == tree. root_idx
431+ # Full tree copy - just copy all metadata
432+ new_tree. n_nodes = tree. n_nodes
433+ new_tree. free_count = tree. free_count
434+ new_tree. free_list[:] = tree. free_list
435+ else
436+ # Subtree copy - need to update free list to exclude unreachable nodes
437+ reachable = fill (false , n_capacity)
438+ mark_reachable! (reachable, new_tree, idx)
439+
440+ # Reset free list with unreachable nodes
441+ new_tree. free_count = 0
442+ new_tree. n_nodes = 0
443+ for i in 1 : n_capacity
444+ if ! reachable[i]
445+ new_tree. free_count += 1
446+ new_tree. free_list[new_tree. free_count] = UInt16 (i)
447+ else
448+ new_tree. n_nodes += 1
449+ end
450+ end
400451 end
401- new_idx = copy_subtree! (new_tree, tree, idx)
402- new_tree. root_idx = new_idx
403452
404- return ArrayNode {T,D,S} (new_tree, new_idx )
453+ return ArrayNode {T,D,S} (new_tree, new_tree . root_idx )
405454end
406455
407456Base. copy (n:: ArrayNode ) = copy_node (n)
408457
458+ # tree_mapreduce implementation
459+ function tree_mapreduce (
460+ f:: F ,
461+ op:: G ,
462+ tree:: ArrayNode{T,D,S} ,
463+ result_type:: Type{RT} = Undefined;
464+ f_on_shared:: H = (result, is_shared) -> result,
465+ break_sharing:: Val{BS} = Val (false ),
466+ ) where {F<: Function ,G<: Function ,H<: Function ,T,D,S,RT,BS}
467+ return tree_mapreduce (f, f, op, tree, result_type; f_on_shared, break_sharing)
468+ end
469+
470+ function tree_mapreduce (
471+ f_leaf:: F1 ,
472+ f_branch:: F2 ,
473+ op:: G ,
474+ tree:: ArrayNode{T,D,S} ,
475+ result_type:: Type{RT} = Undefined;
476+ f_on_shared:: H = (result, is_shared) -> result,
477+ break_sharing:: Val{BS} = Val (false ),
478+ ) where {F1<: Function ,F2<: Function ,G<: Function ,H<: Function ,T,D,S,RT,BS}
479+ # ArrayNode doesn't preserve sharing, so we can use simple recursion
480+ if tree. degree == 0
481+ return f_leaf (tree)
482+ else
483+ # Apply to children
484+ degree = tree. degree
485+ children_results = ntuple (Val (Int (degree))) do i
486+ child = get_children (tree, Val (degree))[i]
487+ tree_mapreduce (
488+ f_leaf, f_branch, op, child, result_type; f_on_shared, break_sharing
489+ )
490+ end
491+
492+ # Reduce children results
493+ self_result = f_branch (tree)
494+ return op (self_result, children_results... )
495+ end
496+ end
497+
409498end # module
0 commit comments