Skip to content

Commit 493df94

Browse files
Merge branch 'groupby_dataset'
2 parents ab75e2c + b8dbe74 commit 493df94

File tree

2 files changed

+45
-16
lines changed

2 files changed

+45
-16
lines changed

src/groupby.jl

Lines changed: 38 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,15 @@ struct GroupMapping{TClass,TUC} <: AbstractGroupMapping where {TClass <: Abstrac
1212
unique_class::TUC
1313
end
1414

15-
struct GroupedDataset{TDS,TF,TGM,TM,TRF} <: AbstractDataset
15+
struct GroupedDataset{TDS<:AbstractDataset,TF,TGM,TM}
16+
ds::TDS # dataset
17+
coordname::Symbol
18+
group_fun::TF # mapping function
19+
groupmap::TGM
20+
map_fun::TM
21+
end
22+
23+
struct ReducedGroupedDataset{TDS,TF,TGM,TM,TRF} <: AbstractDataset
1624
ds::TDS # dataset
1725
coordname::Symbol
1826
group_fun::TF # mapping function
@@ -110,12 +118,12 @@ _dest_indices(j,ku,indices) = _indices_helper(j,ku,1,:,indices...)
110118
@inline _size_getindex(array,sh,n) = sh
111119

112120
#
113-
# methods with GroupedDataset as main argument
121+
# methods with ReducedGroupedDataset as main argument
114122
#
115123

116-
Base.keys(gds::GroupedDataset) = keys(gds.ds)
124+
Base.keys(gds::ReducedGroupedDataset) = keys(gds.ds)
117125

118-
function variable(gds::GroupedDataset,varname::SymbolOrString)
126+
function variable(gds::ReducedGroupedDataset,varname::SymbolOrString)
119127
v = variable(gds.ds,varname)
120128

121129
dim = findfirst(==(gds.coordname),Symbol.(dimnames(v)))
@@ -355,7 +363,17 @@ function groupby(v::AbstractVariable,(coordname,group_fun)::Pair{<:SymbolOrStrin
355363
dim = findfirst(==(Symbol(coordname)),Symbol.(dimnames(v)))
356364
map_fun = identity
357365
groupmap = GroupMapping(class,unique_class)
358-
return GroupedVariable(v,coordname,group_fun,groupmap,dim,map_fun)
366+
return GroupedVariable(v,Symbol(coordname),group_fun,groupmap,dim,map_fun)
367+
end
368+
369+
370+
function groupby(ds::AbstractDataset,(coordname,group_fun)::Pair{<:SymbolOrString,TF}) where TF
371+
c = ds[String(coordname)][:]
372+
class = group_fun.(c)
373+
unique_class = sort(unique(class))
374+
map_fun = identity
375+
groupmap = GroupMapping(class,unique_class)
376+
return GroupedDataset(ds,Symbol(coordname),group_fun,groupmap,map_fun)
359377
end
360378

361379
"""
@@ -376,13 +394,22 @@ end
376394

377395
function ReducedGroupedVariable(gv::GroupedVariable,reduce_fun)
378396
T = eltype(gv.v)
379-
#@show T, reduce_fun
380-
#@show Base.return_types(reduce_fun, (Vector{T},))
381-
397+
@debug "inference " T reduce_fun Base.return_types(reduce_fun, (Vector{T},))
382398
N = ndims(gv.v)
383399
ReducedGroupedVariable{T,N,typeof(gv),typeof(reduce_fun)}(gv,reduce_fun)
384400
end
385401

402+
function ReducedGroupedDataset(gds::GroupedDataset,reduce_fun)
403+
return ReducedGroupedDataset(
404+
gds.ds,
405+
gds.coordname,
406+
gds.group_fun,
407+
gds.groupmap,
408+
gds.map_fun,
409+
reduce_fun,
410+
)
411+
end
412+
386413
"""
387414
gr = reduce(f,gv::GroupedVariable)
388415
@@ -392,19 +419,15 @@ of `gv`) and `d` is an integer of the dimension overwhich one need to reduce
392419
`x`.
393420
"""
394421
Base.reduce(f,gv::GroupedVariable) = ReducedGroupedVariable(gv,f)
422+
Base.reduce(f,gds::GroupedDataset) = ReducedGroupedDataset(gds,f)
395423

396424
for fun in (:maximum, :mean, :median, :minimum, :std, :sum, :var)
397425
@eval $fun(gv::GroupedVariable) = reduce($fun,gv)
426+
@eval $fun(gds::GroupedDataset) = reduce($fun,gds)
398427
end
399428

400429
# methods with ReducedGroupedVariable as main argument
401430

402-
function Base.show(io::IO,::MIME"text/plain",gv::ReducedGroupedVariable)
403-
println(
404-
io,join(string.(size(gv)),'×')," array after reducing using ",
405-
"$(gv.reduce_fun)")
406-
end
407-
408431
Base.ndims(gr::ReducedGroupedVariable) = ndims(gr.gv.v)
409432
Base.size(gr::ReducedGroupedVariable) = ntuple(ndims(gr)) do i
410433
if i == gr.gv.dim
@@ -531,7 +554,7 @@ function dataset(gr::ReducedGroupedVariable)
531554
gv = gr.gv
532555
ds = dataset(gv.v)
533556

534-
return GroupedDataset(
557+
return ReducedGroupedDataset(
535558
ds,gv.coordname,gv.group_fun,
536559
gv.groupmap,
537560
gv.map_fun,

test/test_groupby.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,11 @@ gd = groupby(ds[:data],:time => Dates.Month)
118118
month_sum = sum(gd);
119119
@test month_sum[:,:,:] == d_sum
120120

121+
# group dataset
122+
gds = sum(groupby(ds,:time => Dates.Month))
123+
@test gds["data"][:,:,:] == d_sum
124+
@test gds["lon"][:] == ds["lon"][:]
125+
@test gds["lat"][:] == ds["lat"][:]
121126

122127

123128
gr = month_sum
@@ -203,7 +208,8 @@ gr2 = mean(@groupby(ds["data2"],Dates.Month(time)))
203208
@test gds["lon"][:] == 1:size(data,1)
204209
io = IOBuffer()
205210
show(io,"text/plain",gr)
206-
@test occursin("array", String(take!(io)))
211+
#@test occursin("array", String(take!(io)))
212+
@test occursin("Dimensions", String(take!(io)))
207213

208214
io = IOBuffer()
209215
show(io,"text/plain",gv)

0 commit comments

Comments
 (0)