Skip to content

Commit 6044bc3

Browse files
committed
refactor: use UInt16 for indices in array node
1 parent e8cb83b commit 6044bc3

File tree

1 file changed

+42
-56
lines changed

1 file changed

+42
-56
lines changed

src/ArrayNode.jl

Lines changed: 42 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -29,22 +29,22 @@ struct NodeData{T,D}
2929
val::T
3030
feature::UInt16
3131
op::UInt8
32-
children::NTuple{D,Int8}
32+
children::NTuple{D,UInt16}
3333
end
3434

3535
# Constructor for empty node
3636
function NodeData{T,D}() where {T,D}
3737
return NodeData{T,D}(
38-
UInt8(0), true, zero(T), UInt16(0), UInt8(0), ntuple(_ -> Int8(-1), Val(D))
38+
UInt8(0), true, zero(T), UInt16(0), UInt8(0), ntuple(_ -> UInt16(0), Val(D))
3939
)
4040
end
4141

4242
mutable struct ArrayTree{T,D,S<:StructVector{NodeData{T,D}}}
4343
const nodes::S
44-
root_idx::Int8
45-
n_nodes::Int8
46-
const free_list::Vector{Int8}
47-
free_count::Int8
44+
root_idx::UInt16
45+
n_nodes::UInt16
46+
const free_list::Vector{UInt16}
47+
free_count::UInt16
4848

4949
function ArrayTree{T,D}(n::Int; array_type::Type{<:AbstractVector}=Vector) where {T,D}
5050
# Create backing arrays of the specified type
@@ -53,7 +53,7 @@ mutable struct ArrayTree{T,D,S<:StructVector{NodeData{T,D}}}
5353
val = array_type{T}(undef, n)
5454
feature = array_type{UInt16}(undef, n)
5555
op = array_type{UInt8}(undef, n)
56-
children = array_type{NTuple{D,Int8}}(undef, n)
56+
children = array_type{NTuple{D,UInt16}}(undef, n)
5757

5858
# Create a StructVector from the backing arrays
5959
nodes = StructVector{NodeData{T,D}}((
@@ -65,29 +65,19 @@ mutable struct ArrayTree{T,D,S<:StructVector{NodeData{T,D}}}
6565
children=children,
6666
))
6767

68-
# Initialize all nodes to default values
69-
for i in 1:n
70-
nodes.degree[i] = UInt8(0)
71-
nodes.constant[i] = true
72-
nodes.val[i] = zero(T)
73-
nodes.feature[i] = UInt16(0)
74-
nodes.op[i] = UInt8(0)
75-
nodes.children[i] = ntuple(_ -> Int8(-1), Val(D))
76-
end
77-
7868
S = typeof(nodes)
79-
tree = new{T,D,S}(nodes, Int8(0), Int8(0), Vector{Int8}(undef, n), Int8(n))
69+
tree = new{T,D,S}(nodes, UInt16(0), UInt16(0), Vector{UInt16}(undef, n), UInt16(n))
8070
# Initialize free list
8171
for i in 1:n
82-
tree.free_list[i] = Int8(i)
72+
tree.free_list[i] = UInt16(i)
8373
end
8474
return tree
8575
end
8676
end
8777

8878
struct ArrayNode{T,D,S} <: AbstractExpressionNode{T,D}
8979
tree::ArrayTree{T,D,S}
90-
idx::Int8
80+
idx::UInt16
9181
end
9282

9383
function Base.getproperty(n::ArrayNode{T,D,S}, k::Symbol) where {T,D,S}
@@ -113,18 +103,18 @@ function Base.getproperty(n::ArrayNode{T,D,S}, k::Symbol) where {T,D,S}
113103
# Return tuple of child ArrayNodes wrapped in Nullable
114104
return ntuple(Val(D)) do i
115105
child_idx = @inbounds nodes.children[idx][i]
116-
if child_idx < 0
106+
if child_idx == 0
117107
Nullable(true, n) # Poison node
118108
else
119109
Nullable(false, ArrayNode{T,D,S}(tree, child_idx))
120110
end
121111
end
122112
elseif k == :l # Left child for compatibility
123113
child_idx = @inbounds nodes.children[idx][1]
124-
return child_idx < 0 ? error("No left child") : ArrayNode{T,D,S}(tree, child_idx)
114+
return child_idx == 0 ? error("No left child") : ArrayNode{T,D,S}(tree, child_idx)
125115
elseif k == :r # Right child for compatibility
126116
child_idx = @inbounds nodes.children[idx][2]
127-
return child_idx < 0 ? error("No right child") : ArrayNode{T,D,S}(tree, child_idx)
117+
return child_idx == 0 ? error("No right child") : ArrayNode{T,D,S}(tree, child_idx)
128118
else
129119
error("Unknown field $k")
130120
end
@@ -168,7 +158,7 @@ function allocate_node!(tree::ArrayTree)
168158
return idx
169159
end
170160

171-
function free_node!(tree::ArrayTree, idx::Int8)
161+
function free_node!(tree::ArrayTree, idx::UInt16)
172162
tree.free_count += 1
173163
tree.free_list[tree.free_count] = idx
174164
return tree.n_nodes -= 1
@@ -228,7 +218,6 @@ function ArrayNode{T,D}(
228218
end
229219

230220
if !isnothing(op)
231-
# DEBUG: op=$op, l is nothing? $(isnothing(l)), r is nothing? $(isnothing(r))
232221
_children = if !isnothing(l) && isnothing(r)
233222
(l,)
234223
elseif !isnothing(l) && !isnothing(r)
@@ -238,7 +227,6 @@ function ArrayNode{T,D}(
238227
end
239228

240229
if !isnothing(_children)
241-
# DEBUG: Building node with children, length=length(_children)
242230
degree = length(_children)
243231
tree.nodes.degree[idx] = degree
244232
tree.nodes.op[idx] = op
@@ -248,7 +236,6 @@ function ArrayNode{T,D}(
248236
i -> begin
249237
if i <= length(_children)
250238
child = _children[i]
251-
# DEBUG: Processing child $i, isa ArrayNode? isa(child, ArrayNode)
252239
if isa(child, ArrayNode)
253240
child_tree = getfield(child, :tree)
254241
child_idx = getfield(child, :idx)
@@ -258,21 +245,13 @@ function ArrayNode{T,D}(
258245
else
259246
# Different tree - copy
260247
new_idx = copy_subtree!(tree, child_tree, child_idx)
261-
# DEBUG
262-
# println("Copied child from idx $child_idx to new idx $new_idx")
263-
# println(" Original: constant=", child_tree.nodes.constant[child_idx],
264-
# child_tree.nodes.constant[child_idx] ? ", val=" : ", feature=",
265-
# child_tree.nodes.constant[child_idx] ? child_tree.nodes.val[child_idx] : child_tree.nodes.feature[child_idx])
266-
# println(" Copied: constant=", tree.nodes.constant[new_idx],
267-
# tree.nodes.constant[new_idx] ? ", val=" : ", feature=",
268-
# tree.nodes.constant[new_idx] ? tree.nodes.val[new_idx] : tree.nodes.feature[new_idx])
269248
new_idx
270249
end
271250
else
272-
Int8(-1)
251+
UInt16(0)
273252
end
274253
else
275-
Int8(-1)
254+
UInt16(0)
276255
end
277256
end,
278257
Val(D),
@@ -290,7 +269,7 @@ function ArrayNode{T,D}(
290269
return ArrayNode{T,D,typeof(tree.nodes)}(tree, idx)
291270
end
292271

293-
function copy_subtree!(dst::ArrayTree{T,D}, src::ArrayTree{T,D}, src_idx::Int8) where {T,D}
272+
function copy_subtree!(dst::ArrayTree{T,D}, src::ArrayTree{T,D}, src_idx::UInt16) where {T,D}
294273
dst_idx = allocate_node!(dst)
295274

296275
@inbounds begin
@@ -306,13 +285,13 @@ function copy_subtree!(dst::ArrayTree{T,D}, src::ArrayTree{T,D}, src_idx::Int8)
306285
i -> begin
307286
if i <= degree
308287
child_idx = @inbounds src.nodes.children[src_idx][i]
309-
if child_idx >= 0
288+
if child_idx > 0
310289
copy_subtree!(dst, src, child_idx)
311290
else
312-
Int8(-1)
291+
UInt16(0)
313292
end
314293
else
315-
Int8(-1)
294+
UInt16(0)
316295
end
317296
end, Val(D)
318297
)
@@ -347,10 +326,10 @@ function unsafe_get_children(n::ArrayNode{T,D,S}) where {T,D,S}
347326
return ntuple(
348327
i -> begin
349328
child_idx = @inbounds tree.nodes.children[idx][i]
350-
if child_idx < 0
329+
if child_idx == 0
351330
Nullable(true, n)
352331
else
353-
Nullable(false, ArrayNode{T,D,typeof(tree.nodes)}(tree, child_idx))
332+
Nullable(false, ArrayNode{T,D,S}(tree, child_idx))
354333
end
355334
end,
356335
Val(D),
@@ -362,7 +341,7 @@ function get_children(n::ArrayNode{T,D,S}, ::Val{d}) where {T,D,S,d}
362341
idx = getfield(n, :idx)
363342
return ntuple(i -> begin
364343
child_idx = @inbounds tree.nodes.children[idx][i]
365-
ArrayNode{T,D,typeof(tree.nodes)}(tree, child_idx)
344+
ArrayNode{T,D,S}(tree, child_idx)
366345
end, Val(Int(d)))
367346
end
368347

@@ -377,10 +356,10 @@ function set_children!(n::ArrayNode{T,D,S}, cs::Tuple) where {T,D,S}
377356
if isa(child, ArrayNode)
378357
getfield(child, :idx)
379358
else
380-
Int8(-1)
359+
UInt16(0)
381360
end
382361
else
383-
Int8(-1)
362+
UInt16(0)
384363
end
385364
end, Val(D))
386365
return tree.nodes.children[idx] = child_indices
@@ -396,12 +375,19 @@ function copy_node(n::ArrayNode{T,D,S}; break_sharing::Val{BS}=Val(false)) where
396375
# Add some buffer space
397376
tree_size = max(32, node_count * 2)
398377

399-
# Create new tree for the copy
400-
new_tree = ArrayTree{T,D}(tree_size)
378+
# Determine the array type from the existing tree's nodes
379+
# Default to Vector since that's the most common case
380+
# For other array types, we'd need more sophisticated type extraction
381+
new_tree = if tree.nodes.degree isa Vector
382+
ArrayTree{T,D}(tree_size; array_type=Vector)
383+
else
384+
# For other array types like FixedSizeVector, we just use default
385+
ArrayTree{T,D}(tree_size)
386+
end
401387
new_idx = copy_subtree!(new_tree, tree, idx)
402388
new_tree.root_idx = new_idx
403389

404-
return ArrayNode{T,D,typeof(new_tree.nodes)}(new_tree, new_idx)
390+
return ArrayNode{T,D,S}(new_tree, new_idx)
405391
end
406392

407393
Base.copy(n::ArrayNode) = copy_node(n)
@@ -485,7 +471,7 @@ function set_node!(dst::ArrayNode, src::ArrayNode)
485471
i -> begin
486472
if i <= src.degree
487473
child_idx = @inbounds src_tree.nodes.children[src_idx][i]
488-
if child_idx >= 0
474+
if child_idx > 0
489475
if dst_tree === src_tree
490476
# Same tree
491477
child_idx
@@ -494,10 +480,10 @@ function set_node!(dst::ArrayNode, src::ArrayNode)
494480
copy_subtree!(dst_tree, src_tree, child_idx)
495481
end
496482
else
497-
Int8(-1)
483+
UInt16(0)
498484
end
499485
else
500-
Int8(-1)
486+
UInt16(0)
501487
end
502488
end,
503489
Val(D),
@@ -522,7 +508,7 @@ function tree_mapreduce(
522508
return mapreduce_impl(f, op, tree, getfield(n, :idx))
523509
end
524510

525-
function mapreduce_impl(f::F, op::G, tree::ArrayTree{T,D,S}, idx::Int8) where {F,G,T,D,S}
511+
function mapreduce_impl(f::F, op::G, tree::ArrayTree{T,D,S}, idx::UInt16) where {F,G,T,D,S}
526512
degree = @inbounds tree.nodes.degree[idx]
527513
node = ArrayNode{T,D,S}(tree, idx)
528514
result = f(node)
@@ -531,7 +517,7 @@ function mapreduce_impl(f::F, op::G, tree::ArrayTree{T,D,S}, idx::Int8) where {F
531517
child_results = ntuple(
532518
i -> begin
533519
child_idx = @inbounds tree.nodes.children[idx][i]
534-
if child_idx >= 0
520+
if child_idx > 0
535521
mapreduce_impl(f, op, tree, child_idx)
536522
else
537523
nothing
@@ -554,14 +540,14 @@ function any(f::F, n::ArrayNode) where {F<:Function}
554540
return any_impl(f, tree, getfield(n, :idx))
555541
end
556542

557-
function any_impl(f::F, tree::ArrayTree{T,D,S}, idx::Int8) where {F,T,D,S}
543+
function any_impl(f::F, tree::ArrayTree{T,D,S}, idx::UInt16) where {F,T,D,S}
558544
node = ArrayNode{T,D,S}(tree, idx)
559545
f(node) && return true
560546

561547
degree = @inbounds tree.nodes.degree[idx]
562548
for i in 1:degree
563549
child_idx = @inbounds tree.nodes.children[idx][i]
564-
if child_idx >= 0 && any_impl(f, tree, child_idx)
550+
if child_idx > 0 && any_impl(f, tree, child_idx)
565551
return true
566552
end
567553
end

0 commit comments

Comments
 (0)