Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/DAT/xmap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,9 @@ function mapslices(f, d::YAXArray, addargs...; dims, kwargs...)
Symbol(d)=>Whole()
end
w = windows(d,dw...)
xmap(f,w,inplace=false)
outaxes = YAXArrays.getOutAxis((YAXArrays.ByInference(),), Symbol.(dims), (d,), addargs, f)

xmap(f, w, output=XOutput(outaxes...), inplace=false)
end

struct XFunction{F,O,I} <: Function
Expand Down
30 changes: 17 additions & 13 deletions src/helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,19 +78,21 @@ match_axis(a, ax) = match_axis(get_descriptor(a), ax)
getOutAxis(desc, axlist, incubes, pargs, f) = getAxis(desc, unique(axlist))

function getOutAxis(desc::Tuple{ByInference}, axlist, incubes, pargs, f)
inAxes = map(DD.dims, incubes)
inAxSmall = map(i -> filter(j -> in(j, axlist), i) |> collect, inAxes)
axlist = map(axlist) do ax
isa(ax, String) ? Symbol(ax) : ax
end
inAxSmall = map(i -> DD.dims(i, axlist), incubes)
inSizes = map(i -> (map(length, i)...,), inAxSmall)
intypes = map(eltype, incubes)
testars = map((s, it) -> zeros(it, s...), inSizes, intypes)
map(testars) do ta
ta .= rand(Base.nonmissingtype(eltype(ta)), size(ta)...)
if eltype(ta) >: Missing
# Add some missings
randind = rand(1:length(ta), length(ta) ÷ 10)
ta[randind] .= missing
end
end
# map(testars) do ta
# ta .= rand(Base.nonmissingtype(eltype(ta)), size(ta)...)
# if eltype(ta) >: Missing
# # Add some missings
# randind = rand(1:length(ta), length(ta) ÷ 10)
# ta[randind] .= missing
# end
# end
resu = f(testars..., pargs...)
isa(resu, AbstractArray) ||
isa(resu, Number) ||
Expand All @@ -102,21 +104,23 @@ function getOutAxis(desc::Tuple{ByInference}, axlist, incubes, pargs, f)
end
end
outsizes = size(resu)
allAxes = reduce(union, inAxSmall)
outaxes = map(outsizes, 1:length(outsizes)) do s, il
if s > 2
i = findall(i -> i == s, length.(axlist))
i = findall(i -> i == s, length.(allAxes))
if length(i) == 1
return axlist[i[1]]
return allAxes[i[1]]
elseif length(i) > 1
@info "Found multiple matching axes for output dimension $il"
end
end
return Dim{Symbol("OutAxis$(il)")}( 1:s)
return Dim{Symbol("OutAxis$(il)")}(1:s)
end
if !allunique(outaxes)
#TODO: fallback with axis renaming in this case
error("Could not determine unique output axes from output shape")
end
@show outaxes
return (outaxes...,)
end

Expand Down
Loading