Skip to content

Commit 2bc883f

Browse files
authored
Add transform and combine from DataFrames (#22)
1 parent d1eed4e commit 2bc883f

File tree

5 files changed

+204
-82
lines changed

5 files changed

+204
-82
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
99
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
1010
InvertedIndices = "41ab1584-1d38-5bbf-9106-f11c6c58b48f"
1111
SentinelArrays = "91c51154-3ec4-41a3-a24f-3f23e20d615c"
12+
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1213
TableOperations = "ab02a1b2-a7df-11e8-156e-fb1833f50b87"
1314
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
1415

src/DTables.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ using DataFrames:
1616
make_pair_concrete
1717
using InvertedIndices: Not
1818
using SentinelArrays: ChainedVector
19+
using Statistics: mean
1920
using TableOperations: TableOperations
2021
using Tables:
2122
columnindex,
@@ -54,7 +55,7 @@ import Base:
5455
import DataAPI: leftjoin, ncol, nrow, innerjoin
5556
import Tables:
5657
columnaccess, columnnames, columns, getcolumn, istable, partitions, rowaccess, rows, schema
57-
import DataFrames: broadcast_pair, select, index
58+
import DataFrames: broadcast_pair, combine, groupby, select, index, transform
5859

5960
############################################################################################
6061
# Export
@@ -64,6 +65,7 @@ export All,
6465
AsTable,
6566
Between,
6667
ByRow,
68+
combine,
6769
Cols,
6870
DTable,
6971
DTableColumn,
@@ -73,8 +75,10 @@ export All,
7375
Not,
7476
nrow,
7577
select,
78+
groupby,
7679
tabletype,
7780
tabletype!,
81+
transform,
7882
trim,
7983
trim!
8084
############################################################################################

src/operations/dataframes_interface.jl

Lines changed: 79 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -54,14 +54,16 @@ function _manipulate(df::DTable, normalized_cs::Vector{Any}, copycols::Bool, kee
5454
# STAGE 1: Spawning full column thunks - also multicolumn when needed (except identity)
5555
# These get saved later and used in last stages.
5656
#########
57-
colresults = Dict{Int,Any}()
58-
for (i, (colidx, (f, _))) in enumerate(normalized_cs)
59-
if !(colidx isa AsTable) && !(f isa ByRow) && f != identity
60-
if length(colidx) > 0
61-
cs = DTableColumn.(Ref(df), [colidx...])
62-
colresults[i] = Dagger.@spawn f(cs...)
57+
normalized_cs_results = Dict{Int,Dagger.EagerThunk}()
58+
for (idx, (column_index, (fun, result_column_symbol))) in enumerate(normalized_cs)
59+
if (!(column_index isa AsTable) && !(fun isa ByRow) && fun != identity)
60+
if length(column_index) > 0
61+
normalized_cs_results[idx] = Dagger.@spawn fun(
62+
DTableColumn.(Ref(df), [column_index...])...
63+
)
6364
else
64-
colresults[i] = Dagger.@spawn f() # case of select(d, [] => fun)
65+
# case of select(d, [] => fun) where there are no input columns
66+
normalized_cs_results[idx] = Dagger.@spawn fun()
6567
end
6668
end
6769
end
@@ -71,51 +73,78 @@ function _manipulate(df::DTable, normalized_cs::Vector{Any}, copycols::Bool, kee
7173
# These will be just injected as values in the mapping, because it's a vector full of these values
7274
#########
7375

74-
colresults = Dict{Int,Any}(
75-
k => fetch(Dagger.spawn(length, v)) == 1 ? fetch(v) : v for (k, v) in colresults
76+
mappable_part_of_normalized_cs = filter(
77+
x -> !haskey(normalized_cs_results, x[1]), collect(enumerate(normalized_cs))
7678
)
7779

78-
mapmask = [
79-
haskey(colresults, x) && colresults[x] isa Dagger.EagerThunk for
80-
(x, _) in enumerate(normalized_cs)
81-
]
82-
83-
mappable_part_of_normalized_cs = filter(x -> !mapmask[x[1]], collect(enumerate(normalized_cs)))
84-
8580
#########
8681
# STAGE 3: Mapping function (need to ensure this is compiled only once)
8782
# It's awful right now, but it covers all cases
8883
# Essentially we skip all the non-mappable stuff here
8984
#########
9085

91-
rd = map(x -> select_rowfunction(x, mappable_part_of_normalized_cs, colresults), df)
86+
has_any_mappable = length(mappable_part_of_normalized_cs) > 0
87+
88+
rd = if has_any_mappable || keeprows
89+
map(x -> select_rowfunction(x, mappable_part_of_normalized_cs), df)
90+
else
91+
nothing # in case there's nothing mappable we just go ahead with an empty dtable (just nothing)
92+
end
9293

9394
#########
9495
# STAGE 4: Preping for last stage - getting all the full column thunks with not 1 lengths
9596
#########
96-
cpcolresults = Dict{Int,Any}()
9797

98-
for (k, v) in colresults
99-
if v isa Dagger.EagerThunk
100-
cpcolresults[k] = v
101-
end
98+
fullcolumn_ops_result_lengths = Int[
99+
fetch(Dagger.spawn(length, v)) for v in values(normalized_cs_results)
100+
]
101+
102+
collength_to_compare_against = if has_any_mappable || keeprows
103+
length(df)
104+
else
105+
maximum(fullcolumn_ops_result_lengths)
102106
end
103107

104-
for (_, v) in colresults
105-
if v isa Dagger.EagerThunk
106-
if fetch(Dagger.spawn(length, v)) != length(df)
107-
throw("result column is not the size of the table")
108-
end
109-
end
108+
if !all(map(x -> x == 1 || x == collength_to_compare_against, fullcolumn_ops_result_lengths))
109+
throw(ArgumentError("New columns must have the same length as old columns"))
110110
end
111+
111112
#########
112113
# STAGE 5: Fill columns - meaning the previously omitted full column tasks
113114
# will be now merged into the final DTable
114115
#########
115116

116-
rd = fillcolumns(rd, cpcolresults, normalized_cs, chunk_lengths(df))
117+
new_chunk_lengths = if has_any_mappable || keeprows
118+
chunk_lengths(df)
119+
elseif maximum(fullcolumn_ops_result_lengths) == 1
120+
z = zeros(Int, nchunks(df))
121+
z[1] = 1
122+
z
123+
else
124+
b = maximum(fullcolumn_ops_result_lengths)
125+
a = zeros(Int, nchunks(df))
126+
avg_chunk_length = floor(Int, mean(chunk_lengths(df)))
127+
for (i, c) in enumerate(chunk_lengths(df))
128+
if b >= c
129+
a[i] += c
130+
b -= c
131+
else
132+
a[i] += b
133+
b = 0
134+
end
135+
end
136+
while b > 0
137+
bm = min(b, avg_chunk_length)
138+
push!(a, bm)
139+
b -= bm
140+
end
141+
a
142+
end
117143

118-
return rd
144+
rd2 = fillcolumns(
145+
rd, normalized_cs_results, normalized_cs, new_chunk_lengths, fullcolumn_ops_result_lengths
146+
)
147+
return rd2
119148
end
120149

121150
"""
@@ -147,3 +176,23 @@ function select(
147176
renamecols=renamecols,
148177
)
149178
end
179+
180+
function transform(
181+
df::DTable,
182+
@nospecialize(args...);
183+
copycols::Bool=true,
184+
renamecols::Bool=true,
185+
threads::Bool=true,
186+
)
187+
return select(df, :, args...; copycols=copycols, renamecols=renamecols, threads=threads)
188+
end
189+
190+
function combine(df::DTable, @nospecialize(args...); renamecols::Bool=true, threads::Bool=true)
191+
return manipulate(
192+
df,
193+
map(x -> broadcast_pair(df, x), args)...;
194+
copycols=true,
195+
keeprows=false,
196+
renamecols=renamecols,
197+
)
198+
end
Lines changed: 92 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11

2-
function select_rowfunction(row, mappable_part_of_normalized_cs, colresults)
2+
function select_rowfunction(row, mappable_part_of_normalized_cs)
33
_cs = [
44
begin
55
kk = result_colname === AsTable ? Symbol("AsTable$(i)") : result_colname
@@ -21,8 +21,6 @@ function select_rowfunction(row, mappable_part_of_normalized_cs, colresults)
2121
f.fun(args)
2222
elseif f == identity
2323
args
24-
elseif length(colresults[i]) == 1
25-
colresults[i]
2624
else
2725
throw(ErrorException("Weird unhandled stuff"))
2826
end
@@ -33,62 +31,106 @@ function select_rowfunction(row, mappable_part_of_normalized_cs, colresults)
3331
return (; _cs...)
3432
end
3533

36-
function fillcolumns(
37-
dt::DTable, ics::Dict{Int,Any}, normalized_cs, chunk_lengths_of_original_dt::Vector{Int}
34+
function fillcolumn(
35+
chunk,
36+
csymbols::Union{Vector{DataType},Vector{Symbol},Vector{Union{DataType,Symbol}}},
37+
colfragments::Union{Vector{Dagger.EagerThunk},Vector{Any}},
38+
expected_chunk_length::Int,
39+
normalized_cs::Vector{Any},
3840
)
39-
col_keys_indices = collect(keys(ics))::Vector{Int}
40-
col_vecs = map(x -> ics[x], col_keys_indices)::Union{Vector{Any},Vector{Dagger.EagerThunk}}
41-
42-
f =
43-
(ch, csymbols, colfragments) -> begin
44-
col_vecs_fetched = fetch.(colfragments)
45-
colnames = Vector{Symbol}()
46-
cols = Vector{Any}()
47-
last_astable = 0
41+
col_vecs_fetched = fetch.(colfragments)
42+
colnames = Vector{Symbol}()
43+
cols = Vector{Any}()
44+
last_astable = 0
4845

49-
for (idx, (_, (_, sym))) in enumerate(normalized_cs)
50-
if sym !== AsTable
51-
col = if sym in csymbols
52-
index = something(indexin(csymbols, [sym])...)
53-
col_vecs_fetched[index]
54-
else
55-
getcolumn(ch, sym)
56-
end
57-
push!(colnames, sym)
58-
push!(cols, col)
59-
elseif sym === AsTable
60-
i = findfirst(x -> x === AsTable, csymbols[(last_astable + 1):end])
61-
if i === nothing
62-
c = getcolumn(ch, Symbol("AsTable$(idx)"))
63-
else
64-
last_astable = i
65-
c = col_vecs_fetched[i]
66-
end
67-
68-
push!.(Ref(colnames), columnnames(columns(c)))
69-
push!.(Ref(cols), getcolumn.(Ref(columns(c)), columnnames(columns(c))))
46+
for (idx, (_, (_, sym))) in enumerate(normalized_cs)
47+
if sym !== AsTable
48+
col = if sym in csymbols
49+
index = findfirst(x -> x === sym, csymbols)
50+
if col_vecs_fetched[index] isa AbstractVector
51+
col_vecs_fetched[index]
52+
else
53+
repeat([col_vecs_fetched[index]], expected_chunk_length)
54+
end
55+
else
56+
getcolumn(chunk, sym)
57+
end
58+
push!(colnames, sym)
59+
push!(cols, col)
60+
elseif sym === AsTable
61+
i = findfirst(x -> x === AsTable, csymbols[(last_astable + 1):end])
62+
c = if i === nothing
63+
getcolumn(chunk, Symbol("AsTable$(idx)"))
64+
else
65+
last_astable = i
66+
if col_vecs_fetched[i] isa AbstractVector
67+
col_vecs_fetched[i]
7068
else
71-
throw(ErrorException("something is off"))
69+
repeat([col_vecs_fetched[i]], expected_chunk_length)
7270
end
7371
end
74-
materializer(ch)(
75-
merge(NamedTuple(), (; [e[1] => e[2] for e in zip(colnames, cols)]...))
76-
)
72+
73+
push!.(Ref(colnames), columnnames(columns(c)))
74+
push!.(Ref(cols), getcolumn.(Ref(columns(c)), columnnames(columns(c))))
75+
else
76+
throw(ErrorException("something is off"))
7777
end
78+
end
79+
return materializer(chunk)(
80+
merge(NamedTuple(), (; [e[1] => e[2] for e in zip(colnames, cols)]...))
81+
)
82+
end
83+
84+
function fillcolumns(
85+
dt::Union{Nothing,DTable},
86+
normalized_cs_results::Dict{Int,Dagger.EagerThunk},
87+
normalized_cs::Vector{Any},
88+
new_chunk_lengths::Vector{Int},
89+
fullcolumn_ops_result_lengths::Vector{Int},
90+
)
91+
fullcolumn_ops_indices_in_normalized_cs = collect(keys(normalized_cs_results))::Vector{Int}
92+
fullcolumn_ops_results_ordered = map(
93+
x -> normalized_cs_results[x], fullcolumn_ops_indices_in_normalized_cs
94+
)::Union{Vector{Any},Vector{Dagger.EagerThunk}}
7895

7996
colfragment = (column, s, e) -> Dagger.@spawn getindex(column, s:e)
80-
clenghts = chunk_lengths_of_original_dt
81-
result_column_symbols = getindex.(Ref(map(x -> x[2][2], normalized_cs)), col_keys_indices)
97+
result_column_symbols =
98+
getindex.(Ref(map(x -> x[2][2], normalized_cs)), fullcolumn_ops_indices_in_normalized_cs)
8299

83-
chunks = [
84-
begin
85-
cfrags = [
86-
colfragment(column, 1 + sum(clenghts[1:(i - 1)]), sum(clenghts[1:i])) for
87-
column in col_vecs
88-
]
89-
Dagger.@spawn f(ch, result_column_symbols, cfrags)
90-
end for (i, ch) in enumerate(dt.chunks)
100+
dtchunks = if dt === nothing
101+
[Dagger.spawn(() -> nothing) for _ in 1:length(new_chunk_lengths)]
102+
else
103+
dt.chunks
104+
end
105+
dtchunks_filled = [
106+
x <= length(dtchunks) ? dtchunks[x] : Dagger.spawn(() -> nothing) for
107+
x in 1:length(new_chunk_lengths)
91108
]
92109

93-
return DTable(chunks, dt.tabletype)
110+
chunks = Dagger.EagerThunk[
111+
Dagger.spawn(
112+
fillcolumn,
113+
chunk,
114+
result_column_symbols,
115+
[
116+
if len > 1
117+
colfragment(
118+
column, 1 + sum(new_chunk_lengths[1:(i - 1)]), sum(new_chunk_lengths[1:i])
119+
)
120+
else
121+
column
122+
end for
123+
(column, len) in zip(fullcolumn_ops_results_ordered, fullcolumn_ops_result_lengths)
124+
],
125+
new_chunk_length,
126+
normalized_cs,
127+
) for
128+
(i, (chunk, new_chunk_length)) in enumerate(zip(dtchunks_filled, new_chunk_lengths)) if
129+
new_chunk_length > 0
130+
]
131+
if dt === nothing
132+
return DTable(chunks, nothing)
133+
else
134+
return DTable(chunks, dt.tabletype)
135+
end
94136
end

0 commit comments

Comments
 (0)