Skip to content

Commit 7ab2a92

Browse files
authored
Merge pull request #561 from JuliaDataCubes/fg/engineyax2
Allow Group by another YAXArray
2 parents 8bd94a0 + 52f2bea commit 7ab2a92

File tree

6 files changed

+207
-30
lines changed

6 files changed

+207
-30
lines changed

src/DAT/broadcast.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,14 @@ struct XStyle <: Broadcast.BroadcastStyle end
22
Base.BroadcastStyle(::Broadcast.AbstractArrayStyle, ::XStyle) = XStyle()
33
Base.BroadcastStyle(::XStyle, ::Broadcast.AbstractArrayStyle) = XStyle()
44
Base.BroadcastStyle(::Type{<:YAXArray}) = XStyle()
5-
to_yax(x::Number) = YAXArray((), fill(x))
5+
Base.BroadcastStyle(::Type{<:DimWindowArray}) = XStyle()
6+
to_yax(x::Number) = YAXArray((),fill(x))
67
to_yax(x::DD.AbstractDimArray) = x
8+
to_yax(x::DimWindowArray) = x
9+
10+
Base.broadcastable(d::DimWindowArray) = d
11+
Base.broadcastable(d::YAXArray) = d
12+
713
function Base.broadcasted(::XStyle, f, args...)
814
return Broadcast.Broadcasted{XStyle}(f, args)
915
end

src/DAT/counter.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import OnlineStats
2+
import DiskArrayEngine as DAE
3+
4+
"""
    counter(yax, expected_values=nothing)

Aggregate an `OnlineStats.CountMap` over `yax.data` using
`DAE.aggregate_diskarray`, passing one `i => nothing` entry for every dimension
of `yax`.

If `expected_values` is an `AbstractUnitRange` of `Int`s, the count map is backed
by a dictionary type built with `DAE.KeyConvertDicts.KeyDictType` from an additive
key offset of `1 - first(expected_values)` and `length(expected_values)`.
Otherwise a plain `Dict{eltype(yax),Int}` is used.
"""
function counter(yax, expected_values=nothing)
    T = eltype(yax)
    stattype = if !(expected_values isa AbstractUnitRange{<:Int})
        OnlineStats.CountMap{T,Dict{T,Int}}
    else
        offset = DAE.KeyConvertDicts.AddConst(Val(1 - first(expected_values)))
        dicttype = DAE.KeyConvertDicts.KeyDictType(
            eltype(expected_values), Int, offset, inv(offset), length(expected_values)
        )
        OnlineStats.CountMap{T,dicttype}
    end
    # One `dimension index => nothing` pair per dimension of the input.
    alldims = ntuple(i -> i => nothing, ndims(yax))
    return DAE.aggregate_diskarray(yax.data, stattype, alldims)
end

src/DAT/resample.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ function xresample(yax::YAXArray;to=nothing,method=Linear(),outtype=Float32)
1111
conv = map(newdims) do d
1212
dold = DD.dims(yax.axes,d)
1313
dold === nothing && return nothing
14+
approxequal(dold,d) && return nothing
1415
idim = DD.dimnum(yax.axes,d)
1516
idim=>(valval(dold),valval(d))
1617
end

src/DAT/xmap.jl

Lines changed: 158 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,16 @@ module Xmap
22
import DimensionalData: rebuild, dims, DimArrayOrStack, Dimension, basedims,
33
DimTuple, _group_indices, OpaqueArray, otherdims
44
import DimensionalData as DD
5-
using DiskArrays: find_subranges_sorted, chunktype_from_chunksizes, GridChunks
5+
using DiskArrays: find_subranges_sorted, chunktype_from_chunksizes, GridChunks,isdisk
66
using ..YAXArrays
77
import ..Cubes: YAXArray
88
import DiskArrayEngine as DAE
99
import IntervalSets: Interval
10+
import DiskArrayEngine.compute
1011

1112
include("resample.jl")
1213

13-
export windows, xmap, Whole, xmap, XOutput, compute_to_zarr, xresample, MovingIntervals, XFunction, ⊘
14+
export windows, xmap, Whole, xmap, XOutput, compute_to_zarr, xresample, MovingIntervals, XFunction, ⊘, compute
1415

1516
struct Whole <: DD.AbstractBins end
1617
function DD._group_indices(dim::DD.Dimension, ::Whole; labels=nothing)
@@ -54,6 +55,50 @@ time_mean = xmap(mean, w, inplace=false)
5455
⊘(a,b) = (a,(b,))
5556
windows(A::DimArrayOrStack) = DimWindowArray(A,DD.dims(A),map(d->1:length(d),DD.dims(A)),DD.dims(A))
5657
windows(A::DimArrayOrStack, x) = windows(A, dims(x))
58+
#Method to group by another DimArray defined by groups. This array might be very
59+
#large, so we will not be able to determine the groups in all cases and might
60+
#introduce an axis of unknown length
61+
"""
    UnknownValues{T}

Placeholder used as the values of a group axis when the group labels were not
determined up front. Only the element type `T` of the grouping array is stored.
"""
struct UnknownValues{T}
    eltype::Type{T}
end

"""
    GroupIndices(indices, dims_destroyed, dims_new)

Index object describing a group-by over another array.

# Fields
- `indices`: the raw data of the grouping array.
- `dims_destroyed`: the dimensions of the grouping array.
- `dims_new`: the new group dimension that replaces them.

All fields are parametric so that instances have concrete field types.
"""
struct GroupIndices{I,D,N}
    indices::I
    dims_destroyed::D
    dims_new::N
end
69+
"""
    windows(A, groups::AbstractDimArray; groupname=:group, expected_groups=nothing)

Create a `DimWindowArray` over `A` that is grouped by the values of the
dimensional array `groups`.

The values of the new group axis are chosen as follows:
- `expected_groups`, if given, is used as is;
- if the data of `groups` is a disk array, the values are left undetermined and
  an `UnknownValues` placeholder carrying `eltype(groups)` is used;
- otherwise the unique non-missing values of `groups` are collected. If they are
  all integer-valued numbers covering more than half of the range between their
  minimum and maximum, the full integer range `Int(min):Int(max)` is used;
  in every other case (including non-numeric labels) the sorted unique values
  are used.

In the returned window array the first dimension of `groups` is replaced by a
`Dim{groupname}` axis whose index entry is a `GroupIndices` object, and the
remaining dimensions of `groups` are dropped.
"""
function windows(A::DimArrayOrStack, groups::DD.AbstractDimArray;groupname=:group, expected_groups=nothing)
    groupdims = DD.dims(groups)
    groupdimvals = if !isnothing(expected_groups)
        expected_groups
    elseif isdisk(DD.data(groups))
        UnknownValues(eltype(groups))
    else
        ug = collect(skipmissing(unique(groups)))
        mi, ma = extrema(ug)
        # `isinteger` only has methods for numbers, so guard it to let non-numeric
        # labels (strings, symbols, ...) fall through to the sorted-values branch.
        if all(x -> x isa Number && isinteger(x), ug) && length(ug) / (ma - mi + 1) > 0.5
            Int(mi):Int(ma)
        else
            sort!(ug)
        end
    end
    # Start from plain 1:n index ranges for every dimension of A.
    array_indices = map(DD.dims(A)) do d
        DD.rebuild(d, 1:length(d))
    end
    newindices = GroupIndices(DD.data(groups), groupdims, DD.Dim{groupname}(groupdimvals))
    # The first group dimension carries the GroupIndices object ...
    gd = DD.dims(array_indices, first(groupdims))
    newind = DD.rebuild(gd, newindices)
    array_indices_view = DD.setdims(array_indices, newind)
    # ... and the remaining group dimensions are removed from the index tuple.
    array_indices_pure = map(DD.val, DD.otherdims(array_indices_view, Base.tail(groupdims)))
    dim_orig = DD.dims(A)
    newdim_firstname = DD.rebuild(first(groupdims), groupdimvals)
    newdims = DD.setdims(dim_orig, newdim_firstname)
    newdims = DD.otherdims(newdims, Base.tail(groupdims))
    # Swap the placeholder for the first group dimension with the named group axis.
    irep = DD.dimnum(newdims, first(groupdims))
    newdims = Base.setindex(newdims, DD.Dim{groupname}(groupdimvals), irep)
    return DimWindowArray(A, newdims, array_indices_pure, dim_orig)
end
57102
windows(A::DimArrayOrStack, dimfuncs::Dimension...) = windows(A, dimfuncs)
58103
function windows(
59104
A::DimArrayOrStack, p1::Pair{<:Any,<:Base.Callable}, ps::Pair{<:Any,<:Base.Callable}...;
@@ -63,6 +108,7 @@ function windows(
63108
end
64109
return windows(A, dims)
65110
end
111+
66112
function windows(A::DimArrayOrStack, dimfuncs::DimTuple)
67113
length(otherdims(dimfuncs, dims(A))) > 0 &&
68114
DD.Dimensions._extradimserror(otherdims(dimfuncs, dims(A)))
@@ -85,10 +131,6 @@ function windows(A::DimArrayOrStack, dimfuncs::DimTuple)
85131
array_indices_pure = map(DD.val, array_indices_view)
86132
dim_orig = DD.dims(A)
87133
newdims = DD.setdims(dim_orig, group_dims)
88-
N = ndims(A)
89-
indt = map(eltype,array_indices_pure)
90-
et = Base.promote_op(getindex,typeof(A),indt...)
91-
#etdim = mapreduce(ndims ∘ eltype, +, array_indices_pure)
92134
return DimWindowArray(A, newdims, array_indices_pure, dim_orig)
93135
end
94136

@@ -200,8 +242,10 @@ struct DimWindowArray{A,D,I,DO}
200242
indices::I
201243
dim_orig::DO
202244
end
203-
Base.size(a::DimWindowArray) = length.(a.indices)
204-
Base.getindex(a::DimWindowArray, i::Int...) = a.data[map(getindex,a.indices,i)...]
245+
Base.size(a::DimWindowArray) = length.(a.dims)
246+
index_group(a,i) = a[i]
247+
index_group(a::GroupIndices,i) = findall(isequal(i),a.indices)
248+
Base.getindex(a::DimWindowArray, i::Int...) = a.data.data[map(index_group,a.indices,i)...]
205249
DD.dims(a::DimWindowArray) = a.dims
206250
to_windowarray(d::DimWindowArray) = d
207251
to_windowarray(d) = windows(d)
@@ -211,13 +255,14 @@ function Base.show(io::IO, dw::DimWindowArray)
211255
end
212256

213257

214-
struct XOutput{D<:Tuple{Vararg{DD.Dimension}},T}
258+
struct XOutput{D<:Tuple{Vararg{DD.Dimension}},R,T}
215259
outaxes::D
260+
destroyaxes::R
216261
outtype::T
217262
properties
218263
end
219-
function XOutput(outaxes::DD.Dimension...; outtype=1, properties=Dict())
220-
XOutput(outaxes, outtype,properties)
264+
function XOutput(outaxes...; outtype=1,properties=Dict(),destroyaxes=())
265+
XOutput(outaxes, destroyaxes, outtype,properties)
221266
end
222267

223268
_step(x::AbstractArray{<:Number}) = length(x) > 1 ? (last(x)-first(x))/(length(x)-1) : zero(eltype(x))
@@ -259,13 +304,62 @@ dataeltype(y::DimWindowArray) = eltype(y.data.data)
259304

260305
tupelize(x) = (x,)
261306
tupelize(x::Tuple) = x
307+
"""
    _groupby_xmap(f, ars...; output, inplace)

Handle group-by operations in `xmap`. Exactly one input array is expected, a
`DimWindowArray` in which exactly one entry of `indices` is a `GroupIndices`.

`f` (or `f.f` if `f` is an `XFunction`) is wrapped with `DAE.disk_onlinestat` and
mapped with `xmap` over the underlying data together with the grouping array.
The output is defined on the group dimension; the original dimensions that are
no longer present in the window array are passed as `destroyaxes`.

The `output` and `inplace` keywords are accepted so the call signature matches
`xmap`, but they are currently not used.

Throws an `ArgumentError` if more than one input array is passed or if the input
does not contain exactly one `GroupIndices` entry.
"""
function _groupby_xmap(f, ars...; output, inplace)
    length(ars) == 1 || throw(ArgumentError(
        "group-by xmap expects exactly one input array, got $(length(ars))"))
    g = only(ars)
    inds = findall(i -> isa(i, Xmap.GroupIndices), g.indices)
    # For now we allow only a single group array
    length(inds) == 1 || throw(ArgumentError(
        "group-by xmap expects exactly one grouped dimension, got $(length(inds))"))
    igroup = only(inds)

    preproc, groupconv = (identity, identity)
    _f = isa(f, XFunction) ? f.f : f
    newf = DAE.disk_onlinestat(_f, preproc, groupconv)

    outputs = XOutput(g.dims[igroup], destroyaxes=DD.otherdims(g.dim_orig, g.dims))

    # Rebuild the grouping array from the stored raw data and its dimensions.
    groupar = YAXArray(g.indices[igroup].dims_destroyed, g.indices[igroup].indices)
    return xmap(newf, g.data, groupar, output=outputs)
end
331+
332+
333+
"""
334+
xmap(f, ar...; output = nothing, inplace = nothing)
335+
336+
Maps a function `f` over one or more input arrays `ar`, each of type `YAXArray` or `DimWindowArray`.
262337
338+
`xmap` requires the specification of a type for the output of `f`; the default, `1`, indicates
339+
that the data type should be equal to the element type of the first input array. `output` must be a list of `XOutput` objects, each containing a tuple of axes under which the results are stored and the type of the values stored. If `inplace` is `true`, the original values are replaced in place. `xmap` returns one or more objects of type `YAXArray` or `DimWindowArray` containing a view over the data passed to `f` by overlaying the outputs over the original data arrays. If reduction functions are specified, the `xmap` outputs replace the original data array with the reduced values. During the execution of `xmap`, everything except `f` itself is compiled just once. Specifying `f` as an object of type `XFunction` waits until the actual function is called before compiling it.
340+
341+
xmap will return a lazy representation of the resulting array.
263342
264343
function xmap(f, ars::Union{YAXArrays.Cubes.YAXArray,DimWindowArray}...;
265344
output=XOutput(),
266345
inplace=default_inplace(f),
267346
function_args=(),
268347
function_kwargs=(;))
348+
* `output::Vector{XOutput}`: specifies the output arrays. Each XOutput object contains a tuple of axes (or symbols) to store the result and the element type of the output arrays.
349+
* `inplace::Bool`: if `true`, the function `f` operates in-place, so that pre-allocated output buffers will be passed to the function as arguments.
350+
351+
352+
353+
**Examples.**
354+
355+
# Simple mapping
356+
x = xmap(+, a, b)
357+
358+
359+
"""
360+
function xmap(f, ars::Union{YAXArrays.Cubes.YAXArray,DimWindowArray}...; args=(), kwargs=(;), output=nothing, inplace=nothing, function_args=(), function_kwargs=(;))
361+
output === nothing && (output = default_output(f))
362+
inplace === nothing && (inplace = default_inplace(f))
269363
alldims = mapreduce(approxunion!,ars,init=[]) do ar
270364
DD.dims(ar)
271365
end
@@ -277,6 +371,14 @@ function xmap(f, ars::Union{YAXArrays.Cubes.YAXArray,DimWindowArray}...;
277371
throw(ArgumentError("Duplicated dimensions with different values"))
278372
end
279373

374+
winars = map(to_windowarray,ars)
375+
#Check for any input that contains an array groupby
376+
is_groupby = any(winars) do a
377+
any(Base.Fix2(isa,GroupIndices),a.indices)
378+
end
379+
380+
is_groupby && return _groupby_xmap(f,winars...;output,inplace)
381+
280382
#Create outspecs
281383
output = tupelize(output)
282384

@@ -286,12 +388,25 @@ function xmap(f, ars::Union{YAXArrays.Cubes.YAXArray,DimWindowArray}...;
286388

287389
allinandoutdims = (unique(DD.basedims((alldims..., alloutdims...)))...,)
288390

289-
290391
outaxinfo = map(output) do o
291392
outaxes = o.outaxes
393+
destroydims = o.destroyaxes
292394
addaxes = DD.otherdims(alldims, DD.basedims(outaxes))
293395
outwindows = map(i->[Base.OneTo(length(i))],outaxes)
294-
extrawindows = Base.OneTo.(length.(addaxes))
396+
extrawindows = map(addaxes) do outax
397+
if isnothing(DD.dims(destroydims, outax))
398+
Base.OneTo(length(outax))
399+
else
400+
fill(1, length(outax))
401+
end
402+
end
403+
addaxes = map(addaxes) do outax
404+
if isnothing(DD.dims(destroydims, outax))
405+
outax
406+
else
407+
DD.reducedims(outax, DD.Dim)
408+
end
409+
end
295410
alloutaxes = (outaxes..., addaxes...)
296411
dimsmap = DD.dimnum(allinandoutdims, alloutaxes)
297412
alloutaxes, tupelize(dimsmap), (outwindows..., extrawindows...)
@@ -309,14 +424,15 @@ function xmap(f, ars::Union{YAXArrays.Cubes.YAXArray,DimWindowArray}...;
309424
end
310425
push!(outtypes, outtype)
311426

312-
sout = map(length,ax)
427+
sout = map(win->maximum(maximum,win),w)
313428
DAE.create_outwindows(sout;dimsmap=dm,windows = w)
314429
end
315-
daefunction = DAE.create_userfunction(f, (outtypes...,),
316-
is_mutating=inplace,
317-
allow_threads=false,
318-
args=function_args,
319-
kwargs=function_kwargs)
430+
daefunction = if f isa DAE.UserOp
431+
f
432+
else
433+
DAE.create_userfunction(f, (outtypes...,); is_mutating=inplace, allow_threads=false,
434+
args=function_args, kwargs=function_kwargs)
435+
end
320436
#Create DiskArrayEngine Input arrays
321437
input_arrays = map(ars) do ar
322438
a = to_windowarray(ar)
@@ -461,6 +577,9 @@ XFunction(f::XFunction;kwargs...) = f
461577

462578
default_inplace(f::XFunction) = f.inplace
463579
default_inplace(f) = true
580+
default_output(f) = XOutput()
581+
default_output(f::XFunction) = f.outputs
582+
464583

465584
function Base.broadcasted(f::XFunction,args...)
466585
xmap(f,args...,output = f.outputs, inplace = f.inplace)
@@ -474,6 +593,26 @@ function gmwop_from_conn(conn,nodes)
474593
DAE.GMDWop(inputs, outspecs, op)
475594
end
476595

596+
"""
    compute(yax::YAXArray, args...; kwargs...)

Run the computation wrapped by `yax` and return `yax` rebuilt around the result.

If `yax.data` is not a `DAE.GMWOPResult`, a warning is emitted and `yax` is
returned unchanged. If any input of the wrapped operation is itself a
`DAE.GMWOPResult`, the operations are first put into a `DAE.MwopGraph`, aliases
are removed and the graph is fused before computing. All `args` and `kwargs` are
forwarded to `DAE.compute`.
"""
function compute(yax::YAXArray, args...; kwargs...)
    lazy = yax.data
    if !isa(lazy, DAE.GMWOPResult)
        @warn "Yaxarray does not wrap a computation, nothing to do"
        return yax
    end
    # Chained lazy operations are fused into a single operation before running.
    has_lazy_inputs = any(inp -> isa(inp.a, DAE.GMWOPResult), lazy.op.inars)
    target = if has_lazy_inputs
        graph = DAE.MwopGraph()
        inode = DAE.to_graph!(graph, lazy)
        DAE.remove_aliases!(graph)
        DAE.fuse_graph!(graph)
        fused = DAE.gmwop_from_reducedgraph(graph)
        DAE.results_as_diskarrays(fused)[inode]
    else
        lazy
    end
    return DD.rebuild(yax, DAE.compute(target, args...; kwargs...))
end
615+
477616
"""
478617
compute_to_zarr(ods, path; max_cache=5e8, overwrite=false)
479618

0 commit comments

Comments
 (0)