Skip to content

Commit 92a9149

Browse files
authored
DTableColumn improvements (#12)
1 parent 826abbe commit 92a9149

File tree

7 files changed

+95
-79
lines changed

7 files changed

+95
-79
lines changed

.github/workflows/CI.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ jobs:
2121
matrix:
2222
version:
2323
- '1.7'
24+
- '1.8'
2425
- 'nightly'
2526
os:
2627
- ubuntu-latest

src/table/dataframes_interface.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1+
import DataAPI: All, Between, BroadcastedSelector, Cols
2+
import DataFrames: AsTable, ByRow, ColumnIndex, MultiColumnIndex, normalize_selection
13
import InvertedIndices: BroadcastedInvertedIndex
2-
import DataAPI: Between, All, Cols, BroadcastedSelector
3-
import DataFrames: ColumnIndex, MultiColumnIndex,
4-
ByRow, AsTable, normalize_selection
4+
5+
56

67
# Method for `Pair` arguments: rebuilds the pair by applying `make_pair_concrete`
# to `x.first` and to `x.second`. `@nospecialize` hints the compiler not to
# specialize this method on the concrete type of `x`.
# (Methods for non-`Pair` arguments are not visible in this hunk.)
make_pair_concrete(@nospecialize(x::Pair)) =
78
make_pair_concrete(x.first) => make_pair_concrete(x.second)

src/table/dtable.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,12 @@ mutable struct DTable
2121
chunks::VTYPE
2222
tabletype
2323
schema::Union{Nothing,Tables.Schema}
24-
DTable(chunks::VTYPE, tabletype) = new(chunks, tabletype, nothing)
2524
end
2625

27-
# Removed side of the diff (`-` gutter): one outer constructor per chunk-vector
# element type; each converts `chunks` with `VTYPE(chunks)` and forwards `args...`.
DTable(chunks::Vector{Dagger.EagerThunk}, args...) = DTable(VTYPE(chunks), args...)
28-
DTable(chunks::Vector{Dagger.Chunk}, args...) = DTable(VTYPE(chunks), args...)
26+
# Added side (`+` gutter): outer constructors accepting any `Vector` of chunks.
# Both convert `chunks` with `VTYPE(chunks)`; the two-argument form passes
# `nothing` as the schema, the three-argument form forwards `schema` unchanged.
DTable(chunks::Vector, tabletype) = DTable(VTYPE(chunks), tabletype, nothing)
27+
DTable(chunks::Vector, tabletype, schema) = DTable(VTYPE(chunks), tabletype, schema)
28+
29+
2930

3031
"""
3132
DTable(table; tabletype=nothing) -> DTable

src/table/dtable_column.jl

Lines changed: 58 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,101 +1,87 @@
1-
21
# Iterable handle on a single column of a `DTable`.
#
# Type parameters, as supplied by the `DTableColumn(::DTable, ::Int)` constructor
# in this file:
#   T  - the column's element type taken from the table schema.
#   TT - `typeof(iterate(column_of_a_chunk))`, i.e. the type of the value that
#        `iterate` returns for a chunk's column.
#
# Fields:
#   dtable        - the table the column belongs to.
#   col / colname - index and name of the column.
#   chunk_lengths - per-chunk lengths; their sum is the column length.
#   _chunk, _iter, _chunkstore (added side; `current_chunk`, `current_iterator`,
#   `chunkstore` on the removed side) - mutable iteration cache: the index of the
#   chunk currently loaded, the last result of `iterate` on it, and the fetched
#   column data of that chunk. They are reset at the start of every iteration.
mutable struct DTableColumn{T,TT}
32
dtable::DTable
4-
current_chunk::Int
53
col::Int
64
colname::Symbol
75
chunk_lengths::Vector{Int}
8-
current_iterator::Union{Nothing,TT}
9-
chunkstore::Union{Nothing,Vector{T}}
6+
_chunk::Int
7+
_iter::Union{Nothing,TT}
8+
_chunkstore::Union{Nothing,Vector{T}}
109
end
1110

12-
# Extract column `col` from the contents of one chunk by calling
# `Tables.getcolumn(Tables.columns(...), col)`.
# Removed side: the anonymous function bound to the global `__ff`.
# Added side: the named function `getcolumn_chunk` with `col` restricted to `Int`.
__ff = (ch, col) -> Tables.getcolumn(Tables.columns(ch), col)
11+
function getcolumn_chunk(chunk_contents, col::Int)
12+
return Tables.getcolumn(Tables.columns(chunk_contents), col)
13+
end
1314

14-
# Build a `DTableColumn` for the column at index `col`.
#
# Steps (added side of the diff):
#   1. Read the column element type from `Tables.schema(Tables.columns(d))`.
#   2. Determine the iterator type: spawn a Dagger task per chunk, in order, that
#      returns `typeof(iterate(getcolumn_chunk(chunk, col)))`. The loop stops at
#      the first chunk whose result is not `Nothing`, or when all `nchunks(d)`
#      chunks have been tried, in which case `iterator_type` stays `Nothing`.
#   3. Construct the object with the iteration cache cleared
#      (`_chunk = 0`, `_iter = nothing`, `_chunkstore = nothing`).
#
# The removed side only inspected `dtable.chunks[1]` when determining the
# iterator type.
function DTableColumn(dtable::DTable, col::Int)
15-
column_eltype = Tables.schema(Tables.columns(dtable)).types[col]
16-
iterator_type = fetch(Dagger.spawn((ch, _col) -> typeof(iterate(__ff(ch, _col))), dtable.chunks[1], col))
15+
function DTableColumn(d::DTable, col::Int)
16+
column_eltype = Tables.schema(Tables.columns(d)).types[col]
17+
18+
iterator_type = Nothing
19+
c_idx = 1
20+
while iterator_type === Nothing && c_idx <= nchunks(d)
21+
iterator_type = fetch(Dagger.spawn(
22+
(ch, _col) -> typeof(iterate(getcolumn_chunk(ch, _col))),
23+
d.chunks[c_idx],
24+
col
25+
))
26+
c_idx += 1
27+
end
1728

1829
DTableColumn{column_eltype,iterator_type}(
19-
dtable,
20-
0,
30+
d,
2131
col,
22-
_columnnames_svector(dtable)[col],
23-
chunk_lengths(dtable),
32+
_columnnames_svector(d)[col],
33+
chunk_lengths(d),
34+
0,
2435
nothing,
2536
nothing,
2637
)
2738
end
2839

2940

30-
# Removed side of the diff (every line carries a `-` gutter).
# Random access by global row index `idx`:
#   - scans `chunk_lengths`, tracking the start index `s` of each chunk, to find
#     the chunk whose range `s:(s + e - 1)` contains `idx`;
#   - throws `BoundsError()` when no chunk contains `idx`;
#   - fetches that chunk's column through a Dagger task and then steps
#     `offset - 1` times through `Tables.rows(chunk)` to reach the element.
# NOTE(review): `chunk` is already a single column returned by `__ff`, yet it is
# passed to `Tables.rows` and then `Tables.getcolumn(row, col)` - confirm this was
# intended; the added side of the diff drops this method.
function getindex(dtablecolumn::DTableColumn, idx::Int)
31-
chunk_idx = 0
32-
s = 1
33-
for (i, e) in enumerate(dtablecolumn.chunk_lengths)
34-
if s <= idx < s + e
35-
chunk_idx = i
36-
break
37-
end
38-
s = s + e
39-
end
40-
chunk_idx == 0 && throw(BoundsError())
41-
offset = idx - s + 1
42-
chunk = fetch(Dagger.spawn(__ff, dtablecolumn.dtable.chunks[chunk_idx], dtablecolumn.col))
43-
44-
row, iter = iterate(Tables.rows(chunk))
45-
for _ in 1:(offset-1)
46-
row, iter = iterate(Tables.rows(chunk), iter)
47-
end
48-
Tables.getcolumn(row, dtablecolumn.col)
49-
end
41+
# Name-based constructors.
# `String`: finds the position of `col` among the stringified column names with
# `indexin`, unwraps the single result with `only`, and delegates to the
# `Int`-index constructor.
# `Symbol`: converts to `String` and delegates to the method above.
# NOTE(review): for an unknown name `indexin` yields `nothing`, and no
# `DTableColumn(::DTable, ::Nothing)` method is visible here, so the failure is
# presumably a `MethodError` rather than a descriptive error - confirm.
DTableColumn(d::DTable, col::String) =
42+
DTableColumn(d, only(indexin([col], string.(_columnnames_svector(d)))))
43+
DTableColumn(d::DTable, col::Symbol) = DTableColumn(d, string(col))
5044

51-
# Total number of elements in the column: the sum of the stored per-chunk
# lengths. The removed and added lines differ only in the argument name.
length(dtablecolumn::DTableColumn) = sum(dtablecolumn.chunk_lengths)
45+
length(dtc::DTableColumn) = sum(dtc.chunk_lengths)
5246

5347

54-
# Advance the iteration cache to the next chunk that yields an element.
#
# Added side (`pull_next_chunk!`): while `dtc._iter` is `nothing`, increment
# `dtc._chunk`; if that index is still within `nchunks(dtc.dtable)`, fetch the
# chunk's column into `dtc._chunkstore` and store `iterate(dtc._chunkstore)` in
# `dtc._iter`; an empty chunk gives `nothing`, so the loop moves on. When the
# chunks are exhausted the function returns with `dtc._iter` still `nothing`.
# It does nothing if `dtc._iter` is already non-`nothing`, and always returns
# `nothing`; the result is communicated through the mutated fields.
#
# Removed side (`pull_next_chunk`): same loop, but the chunk index was passed in
# as `chunkidx` and returned to the caller instead of being stored on the object.
function pull_next_chunk(dtablecolumn::DTableColumn, chunkidx::Int)
55-
while dtablecolumn.current_iterator === nothing
56-
chunkidx += 1
57-
if chunkidx <= length(dtablecolumn.dtable.chunks)
58-
dtablecolumn.chunkstore =
59-
fetch(Dagger.spawn(__ff, dtablecolumn.dtable.chunks[chunkidx], dtablecolumn.col))
48+
function pull_next_chunk!(dtc::DTableColumn)
49+
# find first non-empty chunk
50+
while dtc._iter === nothing
51+
dtc._chunk += 1
52+
if dtc._chunk <= nchunks(dtc.dtable)
53+
dtc._chunkstore = fetch(Dagger.spawn(
54+
getcolumn_chunk,
55+
dtc.dtable.chunks[dtc._chunk],
56+
dtc.col
57+
))
6058
else
61-
return chunkidx
59+
return nothing
6260
end
63-
dtablecolumn.current_iterator = iterate(dtablecolumn.chunkstore)
61+
# iterate in case this chunk is empty
62+
dtc._iter = iterate(dtc._chunkstore)
6463
end
65-
return chunkidx
64+
return nothing
6665
end
6766

6867

69-
# Start iterating over the column.
#
# Returns `nothing` immediately when the column has length 0. Otherwise clears
# the cache fields, so every new iteration starts from the first chunk, then
# loads the first chunk that yields an element.
#
# Added side: returns `dtc._iter` directly, i.e. whatever `iterate` produced for
# the loaded chunk's column, so the iteration state handed back to the caller is
# that chunk's own state.
# Removed side: returned `(value, (chunkidx, chunk_state))`, carrying the chunk
# index inside the state instead of on the object.
# NOTE(review): on the added side the chunk position lives on the object, so two
# interleaved iterations over the same instance would presumably interfere.
function iterate(dtablecolumn::DTableColumn)
70-
if length(dtablecolumn) == 0
71-
return nothing
72-
end
73-
dtablecolumn.chunkstore = nothing
74-
dtablecolumn.current_iterator = nothing
75-
chunkidx = pull_next_chunk(dtablecolumn, 0)
76-
ci = dtablecolumn.current_iterator
77-
if ci === nothing
78-
return nothing
79-
else
80-
return (ci[1], (chunkidx, ci[2]))
81-
end
68+
function iterate(dtc::DTableColumn)
69+
length(dtc) == 0 && return nothing
70+
71+
# on every iteration start reset the cache
72+
dtc._chunkstore = nothing
73+
dtc._iter = nothing
74+
dtc._chunk = 0
75+
76+
# pull the first chunk
77+
pull_next_chunk!(dtc)
78+
79+
return dtc._iter
8280
end
8381

84-
# Continue iterating over the column.
#
# Added side: `iter` is the state of the currently loaded chunk's column.
# Returns `nothing` if no chunk is loaded. Otherwise advances within
# `dtc._chunkstore`; if that chunk is exhausted (`dtc._iter === nothing`),
# `pull_next_chunk!` loads the next chunk that yields an element. The returned
# `dtc._iter` is `nothing` once every chunk has been consumed.
#
# Removed side: `iter` was a `(chunkidx, chunk_state)` tuple and the result was
# re-wrapped as `(value, (chunkidx, chunk_state))`.
function iterate(dtablecolumn::DTableColumn, iter)
85-
(chunkidx, i) = iter
86-
cs = dtablecolumn.chunkstore
87-
ci = nothing
88-
if cs !== nothing
89-
ci = iterate(cs, i)
90-
else
91-
return nothing
92-
end
93-
dtablecolumn.current_iterator = ci
94-
chunkidx = pull_next_chunk(dtablecolumn, chunkidx)
95-
ci = dtablecolumn.current_iterator
96-
if ci === nothing
97-
return nothing
98-
else
99-
return (ci[1], (chunkidx, ci[2]))
100-
end
82+
function iterate(dtc::DTableColumn, iter)
83+
dtc._chunkstore === nothing && return nothing
84+
dtc._iter = iterate(dtc._chunkstore, iter)
85+
pull_next_chunk!(dtc)
86+
return dtc._iter
10187
end

src/table/operations.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ function filter(f, d::DTable)
216216
m = TableOperations.filter(_f, _chunk)
217217
Tables.materializer(_chunk)(m)
218218
end
219-
DTable(map(c -> Dagger.spawn(chunk_wrap, c, f), d.chunks), d.tabletype)
219+
DTable(map(c -> Dagger.spawn(chunk_wrap, c, f), d.chunks), d.tabletype, d.schema)
220220
end
221221

222222

test/column.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# Tests for `DTableColumn` iteration via `collect`.
#   - Builds a two-column table (`a = 1:S`, `b = rand(S)`) and wraps it in a
#     `DTable`; the second argument `S ÷ 10` is presumably the chunk size - confirm.
#   - Checks that each column round-trips when selected by index, `String` name,
#     and `Symbol` name.
#   - Checks columns of filtered tables: rows kept only at the start, rows kept
#     only at the end, and no rows at all (`col[1:-1]` is the empty slice), which
#     exercises skipping of chunks that yield no elements.
@testset "DTableColumn" begin
2+
S = 10_000
3+
col_a = collect(1:S)
4+
col_b = rand(S)
5+
nt = (a=col_a, b=col_b)
6+
d = DTable(nt, S ÷ 10)
7+
8+
@test collect(DTableColumn(d, 1)) == col_a
9+
@test collect(DTableColumn(d, "a")) == col_a
10+
@test collect(DTableColumn(d, :a)) == col_a
11+
@test collect(DTableColumn(d, 2)) == col_b
12+
@test collect(DTableColumn(d, "b")) == col_b
13+
@test collect(DTableColumn(d, :b)) == col_b
14+
15+
d2 = filter(x -> x.a <= S / 2, d)
16+
@test collect(DTableColumn(d2, 1)) == col_a[1:Int(S / 2)]
17+
@test collect(DTableColumn(d2, 2)) == col_b[1:Int(S / 2)]
18+
19+
d2 = filter(x -> x.a >= S / 2, d)
20+
@test collect(DTableColumn(d2, 1)) == col_a[Int(S / 2):end]
21+
@test collect(DTableColumn(d2, 2)) == col_b[Int(S / 2):end]
22+
23+
d2 = filter(x -> x.a < 0, d)
24+
@test collect(DTableColumn(d2, 1)) == col_a[1:-1]
25+
@test collect(DTableColumn(d2, 2)) == col_b[1:-1]
26+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,5 @@ using Distributed
1414
# Top-level test set: includes each test file in order. The added line (`17+`)
# registers the new `column.jl` tests.
@testset "DTables.jl" begin
1515
include("table.jl")
1616
include("table_dataframes.jl")
17+
include("column.jl")
1718
end

0 commit comments

Comments
 (0)