Skip to content

Commit 214f3b3

Browse files
authored
Merge pull request #759 from jonas-schulze/vector-trie
Allow non-string indices for Trie
2 parents 529a10d + f62c8b1 commit 214f3b3

File tree

3 files changed

+73
-43
lines changed

3 files changed

+73
-43
lines changed

src/trie.jl

Lines changed: 53 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,60 +1,56 @@
1-
mutable struct Trie{T}
2-
value::T
3-
children::Dict{Char,Trie{T}}
1+
mutable struct Trie{K,V}
2+
value::V
3+
children::Dict{K,Trie{K,V}}
44
is_key::Bool
55

6-
function Trie{T}() where T
7-
self = new{T}()
8-
self.children = Dict{Char,Trie{T}}()
6+
function Trie{K,V}() where {K,V}
7+
self = new{K,V}()
8+
self.children = Dict{K,Trie{K,V}}()
99
self.is_key = false
1010
return self
1111
end
1212

13-
function Trie{T}(ks, vs) where T
14-
t = Trie{T}()
15-
for (k, v) in zip(ks, vs)
16-
t[k] = v
17-
end
18-
return t
13+
function Trie{K,V}(ks, vs) where {K,V}
14+
return Trie{K,V}(zip(ks, vs))
1915
end
2016

21-
function Trie{T}(kv) where T
22-
t = Trie{T}()
17+
function Trie{K,V}(kv) where {K,V}
18+
t = Trie{K,V}()
2319
for (k,v) in kv
2420
t[k] = v
2521
end
2622
return t
2723
end
2824
end
2925

30-
Trie() = Trie{Any}()
31-
Trie(ks::AbstractVector{K}, vs::AbstractVector{V}) where {K<:AbstractString,V} = Trie{V}(ks, vs)
32-
Trie(kv::AbstractVector{Tuple{K,V}}) where {K<:AbstractString,V} = Trie{V}(kv)
33-
Trie(kv::AbstractDict{K,V}) where {K<:AbstractString,V} = Trie{V}(kv)
34-
Trie(ks::AbstractVector{K}) where {K<:AbstractString} = Trie{Nothing}(ks, similar(ks, Nothing))
26+
Trie() = Trie{Any,Any}()
27+
Trie(ks::AbstractVector{K}, vs::AbstractVector{V}) where {K,V} = Trie{eltype(K),V}(ks, vs)
28+
Trie(kv::AbstractVector{Tuple{K,V}}) where {K,V} = Trie{eltype(K),V}(kv)
29+
Trie(kv::AbstractDict{K,V}) where {K,V} = Trie{eltype(K),V}(kv)
30+
Trie(ks::AbstractVector{K}) where {K} = Trie{eltype(K),Nothing}(ks, similar(ks, Nothing))
3531

36-
function Base.setindex!(t::Trie{T}, val, key::AbstractString) where T
37-
value = convert(T, val) # we don't want to iterate before finding out it fails
32+
function Base.setindex!(t::Trie{K,V}, val, key) where {K,V}
33+
value = convert(V, val) # we don't want to iterate before finding out it fails
3834
node = t
3935
for char in key
4036
if !haskey(node.children, char)
41-
node.children[char] = Trie{T}()
37+
node.children[char] = Trie{K,V}()
4238
end
4339
node = node.children[char]
4440
end
4541
node.is_key = true
4642
node.value = value
4743
end
4844

49-
function Base.getindex(t::Trie, key::AbstractString)
45+
function Base.getindex(t::Trie, key)
5046
node = subtrie(t, key)
5147
if node != nothing && node.is_key
5248
return node.value
5349
end
5450
throw(KeyError("key not found: $key"))
5551
end
5652

57-
function subtrie(t::Trie, prefix::AbstractString)
53+
function subtrie(t::Trie, prefix)
5854
node = t
5955
for char in prefix
6056
if !haskey(node.children, char)
@@ -66,30 +62,38 @@ function subtrie(t::Trie, prefix::AbstractString)
6662
return node
6763
end
6864

69-
function Base.haskey(t::Trie, key::AbstractString)
65+
function Base.haskey(t::Trie, key)
7066
node = subtrie(t, key)
7167
node != nothing && node.is_key
7268
end
7369

74-
function Base.get(t::Trie, key::AbstractString, notfound)
70+
function Base.get(t::Trie, key, notfound)
7571
node = subtrie(t, key)
7672
if node != nothing && node.is_key
7773
return node.value
7874
end
7975
return notfound
8076
end
8177

82-
function Base.keys(t::Trie, prefix::AbstractString="", found=AbstractString[])
78+
_concat(prefix::String, char::Char) = string(prefix, char)
79+
_concat(prefix::Vector{T}, char::T) where {T} = vcat(prefix, char)
80+
81+
_empty_prefix(::Trie{Char,V}) where {V} = ""
82+
_empty_prefix(::Trie{K,V}) where {K,V} = K[]
83+
84+
function Base.keys(t::Trie{K,V},
85+
prefix=_empty_prefix(t),
86+
found=Vector{typeof(prefix)}()) where {K,V}
8387
if t.is_key
8488
push!(found, prefix)
8589
end
8690
for (char,child) in t.children
87-
keys(child, string(prefix,char), found)
91+
keys(child, _concat(prefix, char), found)
8892
end
8993
return found
9094
end
9195

92-
function keys_with_prefix(t::Trie, prefix::AbstractString)
96+
function keys_with_prefix(t::Trie, prefix)
9397
st = subtrie(t, prefix)
9498
st != nothing ? keys(st,prefix) : []
9599
end
@@ -101,7 +105,7 @@ end
101105
# see the comments and implementation below for details.
102106
struct TrieIterator
103107
t::Trie
104-
str::AbstractString
108+
str
105109
end
106110

107111
# At the start, there is no previous iteration,
@@ -120,11 +124,11 @@ function Base.iterate(it::TrieIterator, (t, i) = (it.t, 0))
120124
end
121125
end
122126

123-
partial_path(t::Trie, str::AbstractString) = TrieIterator(t, str)
127+
partial_path(t::Trie, str) = TrieIterator(t, str)
124128
Base.IteratorSize(::Type{TrieIterator}) = Base.SizeUnknown()
125129

126130
"""
127-
find_prefixes(t::Trie, str::AbstractString)
131+
find_prefixes(t::Trie, str)
128132
129133
Find all keys from the `Trie` that are prefix of the given string
130134
@@ -137,10 +141,24 @@ julia> find_prefixes(t, "ABCDE")
137141
"A"
138142
"ABC"
139143
"ABCD"
144+
145+
julia> t′ = Trie([1:1, 1:3, 1:4, 2:4]);
146+
147+
julia> find_prefixes(t′, 1:5)
148+
3-element Vector{UnitRange{Int64}}:
149+
1:1
150+
1:3
151+
1:4
152+
153+
julia> find_prefixes(t′, [1,2,3,4,5])
154+
3-element Vector{Vector{Int64}}:
155+
[1]
156+
[1, 2, 3]
157+
[1, 2, 3, 4]
140158
```
141159
"""
142-
function find_prefixes(t::Trie, str::AbstractString)
143-
prefixes = AbstractString[]
160+
function find_prefixes(t::Trie, str::T) where {T}
161+
prefixes = T[]
144162
it = partial_path(t, str)
145163
idx = 0
146164
for t in it
@@ -150,4 +168,4 @@ function find_prefixes(t::Trie, str::AbstractString)
150168
idx = nextind(str, idx)
151169
end
152170
return prefixes
153-
end
171+
end

test/test_deprecations.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# These are the tests for deprecated features, they should be deleted along with them
22

33
@testset "Trie: path iterator" begin
4-
t = Trie{Int}()
4+
t = Trie{Char,Int}()
55
t["rob"] = 27
66
t["roger"] = 52
77
t["kevin"] = Int8(11)

test/test_trie.jl

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
@testset "Trie" begin
22
@testset "Core Functionality" begin
3-
t = Trie{Int}()
3+
t = Trie{Char,Int}()
44
t["amy"] = 56
55
t["ann"] = 15
66
t["emma"] = 30
@@ -19,14 +19,14 @@
1919
ks = ["amy", "ann", "emma", "rob", "roger"]
2020
vs = [56, 15, 30, 27, 52]
2121
kvs = collect(zip(ks, vs))
22-
@test isa(Trie(ks, vs), Trie{Int})
23-
@test isa(Trie(kvs), Trie{Int})
24-
@test isa(Trie(Dict(kvs)), Trie{Int})
25-
@test isa(Trie(ks), Trie{Nothing})
22+
@test isa(Trie(ks, vs), Trie{Char,Int})
23+
@test isa(Trie(kvs), Trie{Char,Int})
24+
@test isa(Trie(Dict(kvs)), Trie{Char,Int})
25+
@test isa(Trie(ks), Trie{Char,Nothing})
2626
end
2727

2828
@testset "partial_path iterator" begin
29-
t = Trie{Int}()
29+
t = Trie{Char,Int}()
3030
t["rob"] = 27
3131
t["roger"] = 52
3232
t["kevin"] = Int8(11)
@@ -53,7 +53,7 @@
5353
@test collect(partial_path(t, "東京")) == [t0, t1, t2]
5454
@test collect(partial_path(t, "東京スカイツリー")) == [t0, t1, t2]
5555
end
56-
56+
5757
@testset "find_prefixes" begin
5858
t = Trie(["A", "ABC", "ABD", "BCD"])
5959
prefixes = find_prefixes(t, "ABCDE")
@@ -66,4 +66,16 @@
6666
@test prefixes == ["東京都", "東京都渋谷区"]
6767
end
6868

69+
@testset "non-string indexing" begin
70+
t = Trie{Int,Int}()
71+
t[[1,2,3,4]] = 1
72+
t[[1,2]] = 2
73+
@test haskey(t, [1,2])
74+
@test get(t, [1,2], nothing) == 2
75+
st = subtrie(t, [1,2,3])
76+
@test keys(st) == [[4]]
77+
@test st[[4]] == 1
78+
@test find_prefixes(t, [1,2,3,5]) == [[1,2]]
79+
@test find_prefixes(t, 1:3) == [1:2]
80+
end
6981
end # @testset Trie

0 commit comments

Comments
 (0)