Skip to content

Commit bfef320

Browse files
keep attributes for reduced grouped variables
1 parent 23a6735 commit bfef320

File tree

2 files changed

+43
-8
lines changed

2 files changed

+43
-8
lines changed

src/groupby.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ end
4444
struct ReducedGroupedVariable{T,N,TGV,TF} <: AbstractVariable{T,N}
4545
gv::TGV
4646
reduce_fun::TF
47+
_attrib::OrderedDict{String,Any}
4748
end
4849

4950

@@ -141,6 +142,11 @@ function variable(gds::ReducedGroupedDataset,varname::SymbolOrString)
141142
end
142143
end
143144

145+
146+
attribnames(gds::ReducedGroupedDataset) = attribnames(gds.ds)
147+
attrib(gds::ReducedGroupedDataset,attribname::SymbolOrString) =
148+
attrib(gds.ds,attribname)
149+
144150
#
145151
# methods with GroupedVariable as main argument
146152
#
@@ -396,7 +402,8 @@ function ReducedGroupedVariable(gv::GroupedVariable,reduce_fun)
396402
T = eltype(gv.v)
397403
@debug "inference " T reduce_fun Base.return_types(reduce_fun, (Vector{T},))
398404
N = ndims(gv.v)
399-
ReducedGroupedVariable{T,N,typeof(gv),typeof(reduce_fun)}(gv,reduce_fun)
405+
_attrib = OrderedDict(gv.v.attrib)
406+
ReducedGroupedVariable{T,N,typeof(gv),typeof(reduce_fun)}(gv,reduce_fun,_attrib)
400407
end
401408

402409
function ReducedGroupedDataset(gds::GroupedDataset,reduce_fun)
@@ -440,6 +447,12 @@ end
440447
dimnames(gr::ReducedGroupedVariable) = dimnames(gr.gv.v)
441448
name(gr::ReducedGroupedVariable) = name(gr.gv.v)
442449

450+
451+
attribnames(gr::ReducedGroupedVariable) = collect(keys(gr._attrib))
452+
attrib(gr::ReducedGroupedVariable,attribname::SymbolOrString) = gr._attrib[attribname]
453+
defAttrib(gr::ReducedGroupedVariable,attribname::SymbolOrString,value) =
454+
gr._attrib[attribname] = value
455+
443456
struct ReducedGroupedVariableStyle <: BroadcastStyle end
444457
Base.BroadcastStyle(::Type{<:ReducedGroupedVariable}) = ReducedGroupedVariableStyle()
445458

test/test_groupby.jl

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,18 @@ using Statistics
66
using CommonDataModel
77
using CommonDataModel:
88
@groupby,
9-
name,
109
GroupedVariable,
10+
MemoryDataset,
1111
ReducedGroupedVariable,
1212
_array_selectdim_indices,
1313
_dest_indices,
1414
_dim_after_getindex,
1515
_indices,
16-
groupby
16+
dataset,
17+
defVar,
18+
groupby,
19+
name,
20+
variable
1721

1822

1923
#include("memory_dataset.jl");
@@ -53,13 +57,14 @@ data3 .= data
5357
data3[1,1,:] .= missing
5458
data3[1,2,1] = missing
5559

56-
TDS(fname,"c") do ds
60+
TDS(fname,"c",attrib = ["title" => "test"]) do ds
5761
defVar(ds,"lon",1:size(data,1),("lon",))
5862
defVar(ds,"lat",1:size(data,2),("lat",))
5963
defVar(ds,"time",time,("time",))
60-
defVar(ds,"data",data,("lon","lat","time"))
64+
defVar(ds,"data",data,("lon","lat","time"),attrib = ["foo" => "bar"])
6165
defVar(ds,"data2",data .+ 1,("lon","lat","time"))
6266
defVar(ds,"data3",data3,("lon","lat","time"))
67+
defVar(ds,"data4",data,("lon","lat","time"),attrib = ["scale_factor" => 2])
6368
end
6469

6570
ds = TDS(fname)
@@ -108,6 +113,9 @@ gd = groupby(ds[:data],:time => Dates.Month);
108113
d_sum = cat([sum(ds[varname][:,:,:][:,:,findall(Dates.month.(ds[coordname][:]) .== m)],dims=3)
109114
for m in 1:12]...,dims=3)
110115

116+
d_mean = cat([mean(ds[varname][:,:,:][:,:,findall(Dates.month.(ds[coordname][:]) .== m)],dims=3)
117+
for m in 1:12]...,dims=3)
118+
111119

112120
gd = groupby(ds["data"],"time" => Dates.Month)
113121
month_sum = sum(gd);
@@ -118,11 +126,25 @@ gd = groupby(ds[:data],:time => Dates.Month)
118126
month_sum = sum(gd);
119127
@test month_sum[:,:,:] == d_sum
120128

121-
# group dataset
122-
gds = sum(groupby(ds,:time => Dates.Month))
123-
@test gds["data"][:,:,:] == d_sum
129+
# group dataset function
130+
gds = mean(groupby(ds,:time => Dates.Month))
131+
@test gds["data"][:,:,:] == d_mean
124132
@test gds["lon"][:] == ds["lon"][:]
125133
@test gds["lat"][:] == ds["lat"][:]
134+
@test gds["data4"][:,:,:] == d_mean
135+
@test gds.attrib["title"] == "test"
136+
@test gds["data"].attrib["foo"] == "bar"
137+
@test collect(keys(gds.attrib)) == ["title"]
138+
@test collect(keys(gds["data"].attrib)) == ["foo"]
139+
140+
# group dataset macro
141+
gds = mean(@groupby(ds,Dates.Month(time)))
142+
@test gds["data"][:,:,:] == d_mean
143+
@test gds["lon"][:] == ds["lon"][:]
144+
@test gds["lat"][:] == ds["lat"][:]
145+
146+
gr = mean(groupby(ds["data4"],:time => Dates.Month))
147+
@test gr[:,:,:] == d_mean
126148

127149

128150
gr = month_sum

0 commit comments

Comments
 (0)