Skip to content

Commit f812632

Browse files
authored
Fix mapslices (#553)
* fix mapslices * add some tests * add import and move to proper location * Require DIskArrays 0.4.18 and use dropdims * add dropdims keyword to mapslices
1 parent 71a1278 commit f812632

File tree

4 files changed

+50
-16
lines changed

4 files changed

+50
-16
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ DataStructures = "0.17, 0.18, 0.19"
3838
DimensionalData = "0.27, 0.28, 0.29"
3939
DiskArrayEngine = "0.2"
4040
DiskArrayTools = "0.1.12"
41-
DiskArrays = "0.3, 0.4.10"
41+
DiskArrays = "0.4.18"
4242
DocStringExtensions = "0.8, 0.9"
4343
Glob = "1.3"
4444
Interpolations = "0.12, 0.13, 0.14, 0.15, 0.16"

src/DAT/xmap.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -406,13 +406,20 @@ compute_to_zarr(ds, "output.zarr")
406406
function xmap end
407407

408408
import Base.mapslices
409-
function mapslices(f, d::YAXArray, addargs...; dims, kwargs...)
409+
function mapslices(f, d::YAXArray, addargs...; dims, dropdims=false, kwargs...)
410410
!isa(dims, Tuple) && (dims = (dims,))
411411
dw = map(dims) do d
412412
Symbol(d)=>Whole()
413413
end
414414
w = windows(d,dw...)
415-
xmap(f,w,inplace=false)
415+
outaxes = YAXArrays.getOutAxis((YAXArrays.ByInference(),), Symbol.(dims), (d,), addargs, f)
416+
417+
returncube = xmap(f, w, output=XOutput(outaxes...), inplace=false)
418+
if dropdims
419+
return Base.dropdims(returncube; dims=Symbol.(dims))
420+
else
421+
return returncube
422+
end
416423
end
417424

418425
struct XFunction{F,O,I} <: Function

src/helpers.jl

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -78,19 +78,21 @@ match_axis(a, ax) = match_axis(get_descriptor(a), ax)
7878
getOutAxis(desc, axlist, incubes, pargs, f) = getAxis(desc, unique(axlist))
7979

8080
function getOutAxis(desc::Tuple{ByInference}, axlist, incubes, pargs, f)
81-
inAxes = map(DD.dims, incubes)
82-
inAxSmall = map(i -> filter(j -> in(j, axlist), i) |> collect, inAxes)
81+
axlist = map(axlist) do ax
82+
isa(ax, String) ? Symbol(ax) : ax
83+
end
84+
inAxSmall = map(i -> DD.dims(i, axlist), incubes)
8385
inSizes = map(i -> (map(length, i)...,), inAxSmall)
8486
intypes = map(eltype, incubes)
8587
testars = map((s, it) -> zeros(it, s...), inSizes, intypes)
86-
map(testars) do ta
87-
ta .= rand(Base.nonmissingtype(eltype(ta)), size(ta)...)
88-
if eltype(ta) >: Missing
89-
# Add some missings
90-
randind = rand(1:length(ta), length(ta) ÷ 10)
91-
ta[randind] .= missing
92-
end
93-
end
88+
# map(testars) do ta
89+
# ta .= rand(Base.nonmissingtype(eltype(ta)), size(ta)...)
90+
# if eltype(ta) >: Missing
91+
# # Add some missings
92+
# randind = rand(1:length(ta), length(ta) ÷ 10)
93+
# ta[randind] .= missing
94+
# end
95+
# end
9496
resu = f(testars..., pargs...)
9597
isa(resu, AbstractArray) ||
9698
isa(resu, Number) ||
@@ -102,16 +104,17 @@ function getOutAxis(desc::Tuple{ByInference}, axlist, incubes, pargs, f)
102104
end
103105
end
104106
outsizes = size(resu)
107+
allAxes = reduce(union, inAxSmall)
105108
outaxes = map(outsizes, 1:length(outsizes)) do s, il
106109
if s > 2
107-
i = findall(i -> i == s, length.(axlist))
110+
i = findall(i -> i == s, length.(allAxes))
108111
if length(i) == 1
109-
return axlist[i[1]]
112+
return allAxes[i[1]]
110113
elseif length(i) > 1
111114
@info "Found multiple matching axes for output dimension $il"
112115
end
113116
end
114-
return Dim{Symbol("OutAxis$(il)")}( 1:s)
117+
return Dim{Symbol("OutAxis$(il)")}(1:s)
115118
end
116119
if !allunique(outaxes)
117120
#TODO: fallback with axis renaming in this case

test/Datasets/datasets.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -528,6 +528,30 @@ end
528528
mean_slice = mapslices(mean, cube; dims="Dim_1")
529529

530530
@test mean_slice[1, :, :] == ones(20, 5)
531+
import DimensionalData as DD
532+
a = YAXArray(reshape(1:1000, 10, 20, 5))
533+
b = mapslices(cumsum, a, dims="Dim_1")
534+
@test size(b) == size(a)
535+
@test DD.dims(b) == DD.dims(a)
536+
@test b[3, 1, 1] == 6
537+
@test b[2, 2, :].data == 23:400:1623
538+
539+
c = mapslices(sum, a, dims="Dim_2")
540+
@test size(c) == (1, 10, 5)
541+
@test c.Dim_2 == DD.rebuild(a.Dim_2, [a.Dim_2.val])
542+
@test c.Dim_1 == a.Dim_1
543+
@test c.Dim_3 == a.Dim_3
544+
@test c[1, 3, 3] == 9960
545+
546+
#Test dropdims as well
547+
d = dropdims(c, dims=:Dim_2)
548+
DD.dims(d) == Base.tail(DD.dims(c))
549+
550+
d = mapslices(sum, a, dims="Dim_2", dropdims=true)
551+
@test size(d) == (10, 5)
552+
@test d.Dim_1 == a.Dim_1
553+
@test d.Dim_3 == a.Dim_3
554+
@test d[3, 3] == 9960
531555
end
532556

533557
@testset "Making Cubes from heterogemous data types" begin

0 commit comments

Comments
 (0)