Skip to content

Commit 363841d

Browse files
committed
Fix type instabilities in AVL tree
Improves benchmark results for #762 v0.18.10 100000 elements DataStructures.AVLTree 6.061 s (7102520 allocations: 114.48 MiB) 1000 elements DataStructures.AVLTree 36.800 ms (28637 allocations: 509.97 KiB) 10 elements DataStructures.AVLTree 134.361 μs (102 allocations: 2.23 KiB) v0.19.0-DEV 100000 elements DataStructures.AVLTree 6.113 s (7102520 allocations: 114.48 MiB) 1000 elements DataStructures.AVLTree 36.298 ms (28637 allocations: 509.97 KiB) 10 elements DataStructures.AVLTree 130.629 μs (102 allocations: 2.23 KiB) Now: 100000 elements DataStructures.AVLTree 213.450 ms (200002 allocations: 9.16 MiB) 1000 elements DataStructures.AVLTree 1.408 ms (2002 allocations: 93.80 KiB) 10 elements DataStructures.AVLTree 7.814 μs (22 allocations: 1008 bytes)
1 parent 62105c0 commit 363841d

File tree

1 file changed

+103
-86
lines changed

1 file changed

+103
-86
lines changed

src/avl_tree.jl

Lines changed: 103 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,24 @@ AVLTreeNode(d) = AVLTreeNode{Any}(d)
1616

1717
AVLTreeNode_or_null{T} = Union{AVLTreeNode{T}, Nothing}
1818

19+
_getproperty(x::Nothing, f) = @assert false
20+
_getproperty(x::AVLTreeNode{T}, f) where {T} = getfield(x, f)
21+
Base.getproperty(x::AVLTreeNode_or_null{T}, f::Symbol) where {T} =
22+
_getproperty(x, f)
23+
24+
_setproperty!(x::Nothing, f, v) = @assert false
25+
_setproperty!(x::AVLTreeNode{T}, f, v) where {T} =
26+
# setfield!(x, f, convert(fieldtype(typeof(x), f), v))
27+
setfield!(x, f, v)
28+
_setproperty!(x::AVLTreeNode{T}, f, ::Nothing) where {T} =
29+
setfield!(x, f, nothing)
30+
_setproperty!(x::AVLTreeNode{T}, f, v::AVLTreeNode{T}) where {T} =
31+
setfield!(x, f, v)
32+
Base.setproperty!(x::AVLTreeNode_or_null{T}, f::Symbol, v) where {T} =
33+
_setproperty!(x, f, v)
34+
Base.setproperty!(x::AVLTreeNode_or_null{T}, f::Symbol, v::AVLTreeNode_or_null{T}) where {T} =
35+
_setproperty!(x, f, v)
36+
1937
mutable struct AVLTree{T}
2038
root::AVLTreeNode_or_null{T}
2139
count::Int
@@ -27,7 +45,7 @@ AVLTree() = AVLTree{Any}()
2745

2846
Base.length(tree::AVLTree) = tree.count
2947

30-
get_height(node::Union{AVLTreeNode, Nothing}) = (node == nothing) ? 0 : node.height
48+
get_height(node::Union{AVLTreeNode, Nothing}) = (node == nothing) ? Int32(0) : node.height
3149

3250
# balance is the difference of height between leftChild and rightChild of a node.
3351
function get_balance(node::Union{AVLTreeNode, Nothing})
@@ -40,25 +58,25 @@ end
4058

4159
# computes the height of the subtree, which basically is
4260
# one added the maximum of the height of the left subtree and right subtree
43-
compute_height(node::AVLTreeNode) = 1 + max(get_height(node.leftChild), get_height(node.rightChild))
61+
compute_height(node::AVLTreeNode) = Int8(1 + max(get_height(node.leftChild), get_height(node.rightChild)))
4462

45-
get_subsize(node::AVLTreeNode_or_null) = (node == nothing) ? 0 : node.subsize
63+
get_subsize(node::AVLTreeNode_or_null) = (node == nothing) ? Int32(0) : node.subsize
4664

4765
# compute the subtree size
4866
function compute_subtree_size(node::AVLTreeNode_or_null)
4967
if node == nothing
50-
return 0
68+
return Int32(0)
5169
else
5270
L = get_subsize(node.leftChild)
5371
R = get_subsize(node.rightChild)
54-
return (L + R + 1)
72+
return (L + R + Int32(1))
5573
end
5674
end
5775

5876
"""
5977
left_rotate(node_x::AVLTreeNode)
6078
61-
Performs a left-rotation on `node_x`, updates height of the nodes, and returns the rotated node.
79+
Performs a left-rotation on `node_x`, updates height of the nodes, and returns the rotated node.
6280
"""
6381
function left_rotate(z::AVLTreeNode)
6482
y = z.rightChild
@@ -75,7 +93,7 @@ end
7593
"""
7694
right_rotate(node_x::AVLTreeNode)
7795
78-
Performs a right-rotation on `node_x`, updates height of the nodes, and returns the rotated node.
96+
Performs a right-rotation on `node_x`, updates height of the nodes, and returns the rotated node.
7997
"""
8098
function right_rotate(z::AVLTreeNode)
8199
y = z.leftChild
@@ -90,9 +108,9 @@ function right_rotate(z::AVLTreeNode)
90108
end
91109

92110
"""
93-
minimum_node(tree::AVLTree, node::AVLTreeNode)
111+
minimum_node(tree::AVLTree, node::AVLTreeNode)
94112
95-
Returns the AVLTreeNode with minimum value in subtree of `node`.
113+
Returns the AVLTreeNode with minimum value in subtree of `node`.
96114
"""
97115
function minimum_node(node::Union{AVLTreeNode, Nothing})
98116
while node != nothing && node.leftChild != nothing
@@ -107,60 +125,60 @@ function search_node(tree::AVLTree{K}, d::K) where K
107125
while node != nothing && node.data != nothing && node.data != d
108126

109127
prev = node
110-
if d < node.data
128+
if d < node.data
111129
node = node.leftChild
112130
else
113131
node = node.rightChild
114132
end
115133
end
116-
134+
117135
return (node == nothing) ? prev : node
118136
end
119137

120-
function Base.haskey(tree::AVLTree{K}, d::K) where K
138+
function Base.haskey(tree::AVLTree{K}, d::K) where K
121139
(tree.root == nothing) && return false
122140
node = search_node(tree, d)
123141
return (node.data == d)
124142
end
125143

126144
Base.in(key, tree::AVLTree) = haskey(tree, key)
127145

128-
function Base.insert!(tree::AVLTree{K}, d::K) where K
146+
function insert_node(node::Nothing, key::K) where K
147+
return AVLTreeNode{K}(key)
148+
end
149+
function insert_node(node::AVLTreeNode{K}, key::K) where K
150+
if key < node.data
151+
node.leftChild = insert_node(node.leftChild, key)
152+
else
153+
node.rightChild = insert_node(node.rightChild, key)
154+
end
155+
156+
node.subsize = compute_subtree_size(node)
157+
node.height = compute_height(node)
158+
balance = get_balance(node)
129159

130-
function insert_node(node::Union{AVLTreeNode, Nothing}, key)
131-
if node == nothing
132-
return AVLTreeNode{K}(key)
133-
elseif key < node.data
134-
node.leftChild = insert_node(node.leftChild, key)
160+
if balance > 1
161+
if key < node.leftChild.data
162+
return right_rotate(node)
135163
else
136-
node.rightChild = insert_node(node.rightChild, key)
137-
end
138-
139-
node.subsize = compute_subtree_size(node)
140-
node.height = compute_height(node)
141-
balance = get_balance(node)
142-
143-
if balance > 1
144-
if key < node.leftChild.data
145-
return right_rotate(node)
146-
else
147-
node.leftChild = left_rotate(node.leftChild)
148-
return right_rotate(node)
149-
end
164+
node.leftChild = left_rotate(node.leftChild)
165+
return right_rotate(node)
150166
end
167+
end
151168

152-
if balance < -1
153-
if key > node.rightChild.data
154-
return left_rotate(node)
155-
else
156-
node.rightChild = right_rotate(node.rightChild)
157-
return left_rotate(node)
158-
end
169+
if balance < -1
170+
if key > node.rightChild.data
171+
return left_rotate(node)
172+
else
173+
node.rightChild = right_rotate(node.rightChild)
174+
return left_rotate(node)
159175
end
160-
161-
return node
162176
end
163177

178+
return node
179+
end
180+
181+
function Base.insert!(tree::AVLTree{K}, d::K) where K
164182
haskey(tree, d) && return tree
165183

166184
tree.root = insert_node(tree.root, d)
@@ -173,55 +191,54 @@ function Base.push!(tree::AVLTree{K}, key0) where K
173191
insert!(tree, key)
174192
end
175193

176-
function Base.delete!(tree::AVLTree{K}, d::K) where K
177-
178-
function delete_node!(node::Union{AVLTreeNode, Nothing}, key)
179-
if key < node.data
180-
node.leftChild = delete_node!(node.leftChild, key)
181-
elseif key > node.data
182-
node.rightChild = delete_node!(node.rightChild, key)
194+
function delete_node!(node::AVLTreeNode{K}, key) where K
195+
if key < node.data
196+
node.leftChild = delete_node!(node.leftChild, key)
197+
elseif key > node.data
198+
node.rightChild = delete_node!(node.rightChild, key)
199+
else
200+
if node.leftChild == nothing
201+
result = node.rightChild
202+
return result
203+
elseif node.rightChild == nothing
204+
result = node.leftChild
205+
return result
183206
else
184-
if node.leftChild == nothing
185-
result = node.rightChild
186-
return result
187-
elseif node.rightChild == nothing
188-
result = node.leftChild
189-
return result
190-
else
191-
result = minimum_node(node.rightChild)
192-
node.data = result.data
193-
node.rightChild = delete_node!(node.rightChild, result.data)
194-
end
207+
result = minimum_node(node.rightChild)
208+
node.data = result.data
209+
node.rightChild = delete_node!(node.rightChild, result.data)
195210
end
196-
197-
node.subsize = compute_subtree_size(node)
198-
node.height = compute_height(node)
199-
balance = get_balance(node)
200-
201-
if balance > 1
202-
if get_balance(node.leftChild) >= 0
203-
return right_rotate(node)
204-
else
205-
node.leftChild = left_rotate(node.leftChild)
206-
return right_rotate(node)
207-
end
211+
end
212+
213+
node.subsize = compute_subtree_size(node)
214+
node.height = compute_height(node)
215+
balance = get_balance(node)
216+
217+
if balance > 1
218+
if get_balance(node.leftChild) >= 0
219+
return right_rotate(node)
220+
else
221+
node.leftChild = left_rotate(node.leftChild)
222+
return right_rotate(node)
208223
end
224+
end
209225

210-
if balance < -1
211-
if get_balance(node.rightChild) <= 0
212-
return left_rotate(node)
213-
else
214-
node.rightChild = right_rotate(node.rightChild)
215-
return left_rotate(node)
216-
end
217-
end
218-
219-
return node
226+
if balance < -1
227+
if get_balance(node.rightChild) <= 0
228+
return left_rotate(node)
229+
else
230+
node.rightChild = right_rotate(node.rightChild)
231+
return left_rotate(node)
232+
end
220233
end
221234

235+
return node
236+
end
237+
238+
function Base.delete!(tree::AVLTree{K}, d::K) where K
222239
# if the key is not in the tree, do nothing and return the tree
223240
!haskey(tree, d) && return tree
224-
241+
225242
# if the key is present, delete it from the tree
226243
tree.root = delete_node!(tree.root, d)
227244
tree.count -= 1
@@ -244,12 +261,12 @@ function sorted_rank(tree::AVLTree{K}, key::K) where K
244261
else
245262
node = node.leftChild
246263
end
247-
end
264+
end
248265
rank += (1 + get_subsize(node.leftChild))
249266
return rank
250267
end
251268

252-
function Base.getindex(tree::AVLTree{K}, ind::Integer) where K
269+
function Base.getindex(tree::AVLTree{K}, ind::Integer) where K
253270
@boundscheck (1 <= ind <= tree.count) || throw(BoundsError("$ind should be in between 1 and $(tree.count)"))
254271
function traverse_tree(node::AVLTreeNode_or_null, idx)
255272
if (node != nothing)
@@ -263,6 +280,6 @@ function Base.getindex(tree::AVLTree{K}, ind::Integer) where K
263280
end
264281
end
265282
end
266-
value = traverse_tree(tree.root, ind)
283+
value = traverse_tree(tree.root, ind)
267284
return value
268-
end
285+
end

0 commit comments

Comments
 (0)