Skip to content

Commit d4e6491

Browse files
authored
Merge pull request #766 from goerch/master
Fix type instabilities in AVL, RB and splay tree
2 parents b2c6a95 + a6c7bac commit d4e6491

File tree

5 files changed

+121
-112
lines changed

5 files changed

+121
-112
lines changed

src/avl_tree.jl

Lines changed: 88 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ end
1414

1515
AVLTreeNode(d) = AVLTreeNode{Any}(d)
1616

17+
Base.setproperty!(x::AVLTreeNode{T}, f::Symbol, v) where {T} =
18+
setfield!(x, f, v)
19+
1720
AVLTreeNode_or_null{T} = Union{AVLTreeNode{T}, Nothing}
1821

1922
mutable struct AVLTree{T}
@@ -27,7 +30,7 @@ AVLTree() = AVLTree{Any}()
2730

2831
Base.length(tree::AVLTree) = tree.count
2932

30-
get_height(node::Union{AVLTreeNode, Nothing}) = (node == nothing) ? 0 : node.height
33+
get_height(node::Union{AVLTreeNode, Nothing}) = (node == nothing) ? Int8(0) : node.height
3134

3235
# balance is the difference of height between leftChild and rightChild of a node.
3336
function get_balance(node::Union{AVLTreeNode, Nothing})
@@ -40,25 +43,25 @@ end
4043

4144
# computes the height of the subtree, which basically is
4245
# one added the maximum of the height of the left subtree and right subtree
43-
compute_height(node::AVLTreeNode) = 1 + max(get_height(node.leftChild), get_height(node.rightChild))
46+
compute_height(node::AVLTreeNode) = Int8(1) + max(get_height(node.leftChild), get_height(node.rightChild))
4447

45-
get_subsize(node::AVLTreeNode_or_null) = (node == nothing) ? 0 : node.subsize
48+
get_subsize(node::AVLTreeNode_or_null) = (node == nothing) ? Int32(0) : node.subsize
4649

4750
# compute the subtree size
4851
function compute_subtree_size(node::AVLTreeNode_or_null)
4952
if node == nothing
50-
return 0
53+
return Int32(0)
5154
else
5255
L = get_subsize(node.leftChild)
5356
R = get_subsize(node.rightChild)
54-
return (L + R + 1)
57+
return (L + R + Int32(1))
5558
end
5659
end
5760

5861
"""
5962
left_rotate(node_x::AVLTreeNode)
6063
61-
Performs a left-rotation on `node_x`, updates height of the nodes, and returns the rotated node.
64+
Performs a left-rotation on `node_x`, updates height of the nodes, and returns the rotated node.
6265
"""
6366
function left_rotate(z::AVLTreeNode)
6467
y = z.rightChild
@@ -75,7 +78,7 @@ end
7578
"""
7679
right_rotate(node_x::AVLTreeNode)
7780
78-
Performs a right-rotation on `node_x`, updates height of the nodes, and returns the rotated node.
81+
Performs a right-rotation on `node_x`, updates height of the nodes, and returns the rotated node.
7982
"""
8083
function right_rotate(z::AVLTreeNode)
8184
y = z.leftChild
@@ -90,9 +93,9 @@ function right_rotate(z::AVLTreeNode)
9093
end
9194

9295
"""
93-
minimum_node(tree::AVLTree, node::AVLTreeNode)
96+
minimum_node(tree::AVLTree, node::AVLTreeNode)
9497
95-
Returns the AVLTreeNode with minimum value in subtree of `node`.
98+
Returns the AVLTreeNode with minimum value in subtree of `node`.
9699
"""
97100
function minimum_node(node::Union{AVLTreeNode, Nothing})
98101
while node != nothing && node.leftChild != nothing
@@ -107,60 +110,60 @@ function search_node(tree::AVLTree{K}, d::K) where K
107110
while node != nothing && node.data != nothing && node.data != d
108111

109112
prev = node
110-
if d < node.data
113+
if d < node.data
111114
node = node.leftChild
112115
else
113116
node = node.rightChild
114117
end
115118
end
116-
119+
117120
return (node == nothing) ? prev : node
118121
end
119122

120-
function Base.haskey(tree::AVLTree{K}, d::K) where K
123+
function Base.haskey(tree::AVLTree{K}, d::K) where K
121124
(tree.root == nothing) && return false
122125
node = search_node(tree, d)
123126
return (node.data == d)
124127
end
125128

126129
Base.in(key, tree::AVLTree) = haskey(tree, key)
127130

128-
function Base.insert!(tree::AVLTree{K}, d::K) where K
131+
function insert_node(node::Nothing, key::K) where K
132+
return AVLTreeNode{K}(key)
133+
end
134+
function insert_node(node::AVLTreeNode{K}, key::K) where K
135+
if key < node.data
136+
node.leftChild = insert_node(node.leftChild, key)
137+
else
138+
node.rightChild = insert_node(node.rightChild, key)
139+
end
140+
141+
node.subsize = compute_subtree_size(node)
142+
node.height = compute_height(node)
143+
balance = get_balance(node)
129144

130-
function insert_node(node::Union{AVLTreeNode, Nothing}, key)
131-
if node == nothing
132-
return AVLTreeNode{K}(key)
133-
elseif key < node.data
134-
node.leftChild = insert_node(node.leftChild, key)
145+
if balance > 1
146+
if key < node.leftChild.data
147+
return right_rotate(node)
135148
else
136-
node.rightChild = insert_node(node.rightChild, key)
137-
end
138-
139-
node.subsize = compute_subtree_size(node)
140-
node.height = compute_height(node)
141-
balance = get_balance(node)
142-
143-
if balance > 1
144-
if key < node.leftChild.data
145-
return right_rotate(node)
146-
else
147-
node.leftChild = left_rotate(node.leftChild)
148-
return right_rotate(node)
149-
end
149+
node.leftChild = left_rotate(node.leftChild)
150+
return right_rotate(node)
150151
end
152+
end
151153

152-
if balance < -1
153-
if key > node.rightChild.data
154-
return left_rotate(node)
155-
else
156-
node.rightChild = right_rotate(node.rightChild)
157-
return left_rotate(node)
158-
end
154+
if balance < -1
155+
if key > node.rightChild.data
156+
return left_rotate(node)
157+
else
158+
node.rightChild = right_rotate(node.rightChild)
159+
return left_rotate(node)
159160
end
160-
161-
return node
162161
end
163162

163+
return node
164+
end
165+
166+
function Base.insert!(tree::AVLTree{K}, d::K) where K
164167
haskey(tree, d) && return tree
165168

166169
tree.root = insert_node(tree.root, d)
@@ -173,55 +176,54 @@ function Base.push!(tree::AVLTree{K}, key0) where K
173176
insert!(tree, key)
174177
end
175178

176-
function Base.delete!(tree::AVLTree{K}, d::K) where K
177-
178-
function delete_node!(node::Union{AVLTreeNode, Nothing}, key)
179-
if key < node.data
180-
node.leftChild = delete_node!(node.leftChild, key)
181-
elseif key > node.data
182-
node.rightChild = delete_node!(node.rightChild, key)
179+
function delete_node!(node::AVLTreeNode{K}, key::K) where K
180+
if key < node.data
181+
node.leftChild = delete_node!(node.leftChild, key)
182+
elseif key > node.data
183+
node.rightChild = delete_node!(node.rightChild, key)
184+
else
185+
if node.leftChild == nothing
186+
result = node.rightChild
187+
return result
188+
elseif node.rightChild == nothing
189+
result = node.leftChild
190+
return result
183191
else
184-
if node.leftChild == nothing
185-
result = node.rightChild
186-
return result
187-
elseif node.rightChild == nothing
188-
result = node.leftChild
189-
return result
190-
else
191-
result = minimum_node(node.rightChild)
192-
node.data = result.data
193-
node.rightChild = delete_node!(node.rightChild, result.data)
194-
end
192+
result = minimum_node(node.rightChild)
193+
node.data = result.data
194+
node.rightChild = delete_node!(node.rightChild, result.data)
195195
end
196-
197-
node.subsize = compute_subtree_size(node)
198-
node.height = compute_height(node)
199-
balance = get_balance(node)
200-
201-
if balance > 1
202-
if get_balance(node.leftChild) >= 0
203-
return right_rotate(node)
204-
else
205-
node.leftChild = left_rotate(node.leftChild)
206-
return right_rotate(node)
207-
end
196+
end
197+
198+
node.subsize = compute_subtree_size(node)
199+
node.height = compute_height(node)
200+
balance = get_balance(node)
201+
202+
if balance > 1
203+
if get_balance(node.leftChild) >= 0
204+
return right_rotate(node)
205+
else
206+
node.leftChild = left_rotate(node.leftChild)
207+
return right_rotate(node)
208208
end
209+
end
209210

210-
if balance < -1
211-
if get_balance(node.rightChild) <= 0
212-
return left_rotate(node)
213-
else
214-
node.rightChild = right_rotate(node.rightChild)
215-
return left_rotate(node)
216-
end
217-
end
218-
219-
return node
211+
if balance < -1
212+
if get_balance(node.rightChild) <= 0
213+
return left_rotate(node)
214+
else
215+
node.rightChild = right_rotate(node.rightChild)
216+
return left_rotate(node)
217+
end
220218
end
221219

220+
return node
221+
end
222+
223+
function Base.delete!(tree::AVLTree{K}, d::K) where K
222224
# if the key is not in the tree, do nothing and return the tree
223225
!haskey(tree, d) && return tree
224-
226+
225227
# if the key is present, delete it from the tree
226228
tree.root = delete_node!(tree.root, d)
227229
tree.count -= 1
@@ -244,12 +246,12 @@ function sorted_rank(tree::AVLTree{K}, key::K) where K
244246
else
245247
node = node.leftChild
246248
end
247-
end
249+
end
248250
rank += (1 + get_subsize(node.leftChild))
249251
return rank
250252
end
251253

252-
function Base.getindex(tree::AVLTree{K}, ind::Integer) where K
254+
function Base.getindex(tree::AVLTree{K}, ind::Integer) where K
253255
@boundscheck (1 <= ind <= tree.count) || throw(BoundsError("$ind should be in between 1 and $(tree.count)"))
254256
function traverse_tree(node::AVLTreeNode_or_null, idx)
255257
if (node != nothing)
@@ -263,6 +265,6 @@ function Base.getindex(tree::AVLTree{K}, ind::Integer) where K
263265
end
264266
end
265267
end
266-
value = traverse_tree(tree.root, ind)
268+
value = traverse_tree(tree.root, ind)
267269
return value
268-
end
270+
end

src/red_black_tree.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ end
1717
RBTreeNode() = RBTreeNode{Any}()
1818
RBTreeNode(d) = RBTreeNode{Any}(d)
1919

20+
Base.setproperty!(x::RBTreeNode{K}, f::Symbol, v) where {K} =
21+
setfield!(x, f, v)
22+
2023
function create_null_node(K::Type)
2124
node = RBTreeNode{K}()
2225
node.color = false

src/splay_tree.jl

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,15 @@ end
1111
SplayTreeNode(d) = SplayTreeNode{Any}(d)
1212
SplayTreeNode() = SplayTreeNode{Any}()
1313

14+
Base.setproperty!(x::SplayTreeNode{K}, f::Symbol, v) where {K} =
15+
setfield!(x, f, v)
16+
1417
mutable struct SplayTree{K}
1518
root::Union{SplayTreeNode{K}, Nothing}
1619
count::Int
1720

1821
SplayTree{K}() where K = new{K}(nothing, 0)
19-
end
22+
end
2023

2124
Base.length(tree::SplayTree) = tree.count
2225

@@ -41,7 +44,7 @@ function left_rotate!(tree::SplayTree, node_x::SplayTreeNode)
4144
node_y.leftChild = node_x
4245
end
4346
node_x.parent = node_y
44-
end
47+
end
4548

4649
function right_rotate!(tree::SplayTree, node_x::SplayTreeNode)
4750
node_y = node_x.leftChild
@@ -59,7 +62,7 @@ function right_rotate!(tree::SplayTree, node_x::SplayTreeNode)
5962
end
6063
node_y.rightChild = node_x
6164
node_x.parent = node_y
62-
end
65+
end
6366

6467
# The splaying operation moves node_x to the root of the tree using the series of rotations.
6568
function splay!(tree::SplayTree, node_x::SplayTreeNode)
@@ -71,7 +74,7 @@ function splay!(tree::SplayTree, node_x::SplayTreeNode)
7174
if node_x == parent.leftChild
7275
# zig rotation
7376
right_rotate!(tree, node_x.parent)
74-
else
77+
else
7578
# zag rotation
7679
left_rotate!(tree, node_x.parent)
7780
end
@@ -104,7 +107,7 @@ function maximum_node(node::Union{SplayTreeNode, Nothing})
104107
return node
105108
end
106109

107-
# Join operations joins two trees S and T
110+
# Join operations joins two trees S and T
108111
# All the items in S are smaller than the items in T.
109112
# This is a two-step process.
110113
# In the first step, splay the largest node in S. This moves the largest node to the root node.
@@ -157,10 +160,10 @@ function Base.delete!(tree::SplayTree{K}, d::K) where K
157160
x = search_node(tree, d)
158161
(x == nothing) && return tree
159162
t = nothing
160-
s = nothing
161-
163+
s = nothing
164+
162165
splay!(tree, x)
163-
166+
164167
if x.rightChild !== nothing
165168
t = x.rightChild
166169
t.parent = nothing
@@ -211,7 +214,7 @@ function Base.push!(tree::SplayTree{K}, d0) where K
211214
return tree
212215
end
213216

214-
function Base.getindex(tree::SplayTree{K}, ind) where K
217+
function Base.getindex(tree::SplayTree{K}, ind) where K
215218
@boundscheck (1 <= ind <= tree.count) || throw(KeyError("$ind should be in between 1 and $(tree.count)"))
216219
function traverse_tree_inorder(node::Union{SplayTreeNode, Nothing})
217220
if (node != nothing)
@@ -222,6 +225,6 @@ function Base.getindex(tree::SplayTree{K}, ind) where K
222225
return K[]
223226
end
224227
end
225-
arr = traverse_tree_inorder(tree.root)
228+
arr = traverse_tree_inorder(tree.root)
226229
return @inbounds arr[ind]
227230
end

test/runtests.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ using Serialization
55

66
import DataStructures: IntSet
77

8-
@test [] == detect_ambiguities(Base, Core, DataStructures)
8+
@test [] == detect_ambiguities(Core, DataStructures)
9+
@test [] == detect_ambiguities(Base, DataStructures)
910

1011
tests = ["deprecations",
1112
"int_set",
@@ -34,7 +35,7 @@ tests = ["deprecations",
3435
"dibit_vector",
3536
"swiss_dict",
3637
"avl_tree",
37-
"red_black_tree",
38+
"red_black_tree",
3839
"splay_tree"
3940
]
4041

0 commit comments

Comments
 (0)