Skip to content

Commit 8fdf15a

Browse files
committed
add methods, write tests
1 parent 4fcce01 commit 8fdf15a

File tree

4 files changed

+187
-30
lines changed

4 files changed

+187
-30
lines changed

src/DataStructures.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ module DataStructures
107107
include("dibit_vector.jl")
108108
include("avl_tree.jl")
109109
include("red_black_tree.jl")
110-
110+
include("splay_tree.jl")
111111
include("deprecations.jl")
112+
112113
end

src/splay_tree.jl

Lines changed: 48 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@ SplayTreeNode() = SplayTreeNode{Any}()
1616

1717
mutable struct SplayTree{K}
1818
root::SplayTreeNode_or_null{K}
19+
count::Int
1920

20-
SplayTree{K}() where K = new{K}(nothing)
21+
SplayTree{K}() where K = new{K}(nothing, 0)
2122
end
2223

23-
SplayTree(d) = SplayTree{Any}(d)
24+
Base.length(tree::SplayTree) = tree.count
25+
2426
SplayTree() = SplayTree{Any}()
2527

2628
function left_rotate!(tree::SplayTree, node_x::SplayTreeNode)
@@ -98,8 +100,9 @@ function splay!(tree::SplayTree, node_x::SplayTreeNode)
98100
end
99101
end
100102

101-
function maximum(node::SplayTreeNode)
102-
while !isa(node.rightChild, Nothing)
103+
function maximum_node(node::SplayTreeNode_or_null)
104+
(node == nothing) && return node
105+
while node.rightChild != nothing
103106
node = node.rightChild
104107
end
105108
return node
@@ -111,52 +114,48 @@ function _join(tree::SplayTree ,s::SplayTreeNode_or_null, t::SplayTreeNode_or_nu
111114
elseif isa(t, Nothing)
112115
return s
113116
else
114-
x = maximum(s)
117+
x = maximum_node(s)
115118
splay!(tree, x)
116119
x.rightChild = t
117120
t.parent = x
118121
return x
119122
end
120123
end
121124

122-
function search_by_node(node::SplayTreeNode_or_null{K}, d::K) where K
123-
while !isa(node, Nothing)
124-
if node.data == d
125-
break
126-
elseif node.data < d
127-
if !isa(node.rightChild, Nothing)
128-
node = node.rightChild
129-
else
130-
break
131-
end
125+
function search_node(tree::SplayTree{K}, d::K) where K
126+
node = tree.root
127+
prev = nothing
128+
while node != nothing && node.data != d
129+
prev = node
130+
if node.data < d
131+
node = node.rightChild
132132
else
133-
if isa(node.leftChild, Nothing)
134-
node = node.leftChild
135-
else
136-
break
137-
end
133+
node = node.leftChild
138134
end
139135
end
140-
return node
136+
return (node == nothing) ? prev : node
141137
end
142138

143-
function search_key(tree::SplayTree{K}, d::K) where K
139+
function haskey(tree::SplayTree{K}, d::K) where K
144140
node = tree.root
145141
if isa(node, Nothing)
146142
return false
147143
else
148-
node = search_by_node(node, d)
144+
node = search_node(tree, d)
149145
isa(node, Nothing) && return false
150146
is_found = (node.data == d)
151147
is_found && splay!(tree, node)
152148
return is_found
153149
end
154150
end
155151

152+
153+
Base.in(key, tree::SplayTree) = haskey(tree, key)
154+
156155
function Base.delete!(tree::SplayTree{K}, d::K) where K
157156
node = tree.root
158-
x = search_by_node(node, d)
159-
isa(x, Nothing) && return tree
157+
x = search_node(tree, d)
158+
(x == nothing) && return tree
160159
t = nothing
161160
s = nothing
162161

@@ -175,11 +174,12 @@ function Base.delete!(tree::SplayTree{K}, d::K) where K
175174
end
176175

177176
tree.root = _join(tree, s.leftChild, t)
177+
tree.count -= 1
178178
return tree
179179
end
180180

181181
function Base.insert!(tree::SplayTree{K}, d::K) where K
182-
is_present = search_by_node(tree.root, d)
182+
is_present = search_node(tree, d)
183183
if !isa(is_present, Nothing) && (is_present.data == d)
184184
return tree
185185
end
@@ -206,5 +206,26 @@ function Base.insert!(tree::SplayTree{K}, d::K) where K
206206
y.rightChild = node
207207
end
208208
splay!(tree, node)
209+
tree.count += 1
209210
return tree
210-
end
211+
end
212+
213+
function Base.push!(tree::SplayTree{K}, key0) where K
214+
key = convert(K, key0)
215+
insert!(tree, key)
216+
end
217+
218+
function Base.getindex(tree::SplayTree{K}, ind) where K
219+
@boundscheck (1 <= ind <= tree.count) || throw(BoundsError("$ind should be in between 1 and $(tree.count)"))
220+
function traverse_tree_inorder(node::SplayTreeNode_or_null)
221+
if (node != nothing)
222+
left = traverse_tree_inorder(node.leftChild)
223+
right = traverse_tree_inorder(node.rightChild)
224+
append!(push!(left, node.data), right)
225+
else
226+
return K[]
227+
end
228+
end
229+
arr = traverse_tree_inorder(tree.root)
230+
return @inbounds arr[ind]
231+
end

test/runtests.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,10 @@ tests = ["deprecations",
3232
"robin_dict",
3333
"ordered_robin_dict",
3434
"dibit_vector",
35-
"red_black_tree",
3635
"swiss_dict",
37-
"avl_tree"
36+
"avl_tree",
37+
"red_black_tree",
38+
"splay_tree"
3839
]
3940

4041
if length(ARGS) > 0

test/test_splay_tree.jl

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
@testset "SplayTree" begin
2+
@testset "inserting values" begin
3+
t = SplayTree{Int}()
4+
for i in 1:100
5+
insert!(t, i)
6+
end
7+
8+
@test length(t) == 100
9+
10+
for i in 1:100
11+
@test haskey(t, i)
12+
end
13+
14+
for i = 101:200
15+
@test !haskey(t, i)
16+
end
17+
end
18+
19+
@testset "deleting values" begin
20+
t = SplayTree{Int}()
21+
for i in 1:100
22+
insert!(t, i)
23+
end
24+
for i in 1:2:100
25+
delete!(t, i)
26+
end
27+
28+
@test length(t) == 50
29+
30+
for i in 1:100
31+
if iseven(i)
32+
@test haskey(t, i)
33+
else
34+
@test !haskey(t, i)
35+
end
36+
end
37+
38+
for i in 1:2:100
39+
insert!(t, i)
40+
end
41+
42+
@test length(t) == 100
43+
end
44+
45+
@testset "handling different cases of delete!" begin
46+
t2 = SplayTree()
47+
for i in 1:100000
48+
insert!(t2, i)
49+
end
50+
51+
@test length(t2) == 100000
52+
53+
nums = rand(1:100000, 8599)
54+
visited = Set()
55+
for num in nums
56+
if !(num in visited)
57+
delete!(t2, num)
58+
push!(visited, num)
59+
end
60+
end
61+
62+
for i in visited
63+
@test !haskey(t2, i)
64+
end
65+
@test (length(t2) + length(visited)) == 100000
66+
end
67+
68+
@testset "handling different cases of insert!" begin
69+
nums = rand(1:100000, 1000)
70+
t3 = SplayTree()
71+
uniq_nums = Set(nums)
72+
for num in nums
73+
insert!(t3, num)
74+
end
75+
@test length(t3) == length(uniq_nums)
76+
end
77+
78+
@testset "in" begin
79+
t4 = SplayTree{Char}()
80+
push!(t4, 'a')
81+
push!(t4, 'b')
82+
@test length(t4) == 2
83+
@test in('a', t4)
84+
@test !in('c', t4)
85+
end
86+
87+
@testset "search_node" begin
88+
t5 = SplayTree()
89+
for i in 1:32
90+
push!(t5, i)
91+
end
92+
n1 = search_node(t5, 21)
93+
@test n1.data == 21
94+
n2 = search_node(t5, 35)
95+
@test n2.data == 32
96+
n3 = search_node(t5, 0)
97+
@test n3.data == 1
98+
end
99+
100+
@testset "getindex" begin
101+
t6 = SplayTree{Int}()
102+
for i in 1:10
103+
push!(t6, i)
104+
end
105+
for i in 1:10
106+
@test t6[i] == i
107+
end
108+
@test_throws BoundsError getindex(t6, 0)
109+
@test_throws BoundsError getindex(t6, 11)
110+
end
111+
112+
@testset "key conversion in push!" begin
113+
t7 = SplayTree{Int}()
114+
push!(t7, Int8(1))
115+
@test length(t7) == 1
116+
@test haskey(t7, 1)
117+
end
118+
119+
@testset "maximum_node" begin
120+
t8 = SplayTree()
121+
@test maximum_node(t8.root) == nothing
122+
for i in 1:32
123+
push!(t8, i)
124+
end
125+
m1 = maximum_node(t8.root)
126+
@test m1.data == 32
127+
node = t8.root
128+
while node.rightChild != nothing
129+
m = maximum_node(node.rightChild)
130+
@test m == m1
131+
node = node.rightChild
132+
end
133+
end
134+
end

0 commit comments

Comments
 (0)