Skip to content

Commit 24dea7c

Browse files
committed
Reductions preserve the AxisArray wrapper (fixes #55)
1 parent c56dc40 commit 24dea7c

File tree

2 files changed

+91
-0
lines changed

2 files changed

+91
-0
lines changed

src/core.jl

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,55 @@ Base.similar{S}(A::AxisArray, ::Type{S}, ax1::Axis, axs::Axis...) = similar(A, S
280280
end
281281
end
282282

283+
# These methods allow us to preserve the AxisArray under reductions
284+
# Note that we only extend the following two methods, and then have it
285+
# dispatch to package-local `reduced_indices` and `reduced_indices0`
286+
# methods. This avoids a whole slew of ambiguities.
287+
Base.reduced_indices(A::AxisArray, region) = reduced_indices(axes(A), region)
288+
Base.reduced_indices0(A::AxisArray, region) = reduced_indices0(axes(A), region)
289+
290+
reduced_indices{N}(axs::Tuple{Vararg{Axis,N}}, ::Tuple{}) = axs
291+
reduced_indices0{N}(axs::Tuple{Vararg{Axis,N}}, ::Tuple{}) = axs
292+
reduced_indices{N}(axs::Tuple{Vararg{Axis,N}}, region::Integer) =
293+
reduced_indices(axs, (region,))
294+
reduced_indices0{N}(axs::Tuple{Vararg{Axis,N}}, region::Integer) =
295+
reduced_indices0(axs, (region,))
296+
297+
reduced_indices{N}(axs::Tuple{Vararg{Axis,N}}, region::Dims) =
298+
map((ax,d)->dregion ? reduced_axis(ax) : ax, axs, ntuple(identity, Val{N}))
299+
reduced_indices0{N}(axs::Tuple{Vararg{Axis,N}}, region::Dims) =
300+
map((ax,d)->dregion ? reduced_axis0(ax) : ax, axs, ntuple(identity, Val{N}))
301+
302+
@inline reduced_indices{Ax<:Axis}(axs::Tuple{Vararg{Axis}}, region::Type{Ax}) =
303+
_reduced_indices(reduced_axis, (), region, axs...)
304+
@inline reduced_indices(axs::Tuple{Vararg{Axis}}, region::Axis) =
305+
_reduced_indices(reduced_axis, (), region, axs...)
306+
@inline reduced_indices0{Ax<:Axis}(axs::Tuple{Vararg{Axis}}, region::Type{Ax}) =
307+
_reduced_indices(reduced_axis0, (), region, axs...)
308+
@inline reduced_indices0(axs::Tuple{Vararg{Axis}}, region::Axis) =
309+
_reduced_indices(reduced_axis0, (), region, axs...)
310+
311+
reduced_indices(axs::Tuple{Vararg{Axis}}, region::Tuple) =
312+
reduced_indices(reduced_indices(axs, region[1]), tail(region))
313+
reduced_indices(axs::Tuple{Vararg{Axis}}, region::Tuple{Vararg{Axis}}) =
314+
reduced_indices(reduced_indices(axs, region[1]), tail(region))
315+
reduced_indices0(axs::Tuple{Vararg{Axis}}, region::Tuple) =
316+
reduced_indices0(reduced_indices0(axs, region[1]), tail(region))
317+
reduced_indices0(axs::Tuple{Vararg{Axis}}, region::Tuple{Vararg{Axis}}) =
318+
reduced_indices0(reduced_indices0(axs, region[1]), tail(region))
319+
320+
@inline _reduced_indices{name}(f, out, chosen::Type{Axis{name}}, ax::Axis{name}, axs...) =
321+
_reduced_indices(f, (out..., f(ax)), chosen, axs...)
322+
@inline _reduced_indices{name}(f, out, chosen::Axis{name}, ax::Axis{name}, axs...) =
323+
_reduced_indices(f, (out..., f(ax)), chosen, axs...)
324+
@inline _reduced_indices(f, out, chosen, ax::Axis, axs...) =
325+
_reduced_indices(f, (out..., ax), chosen, axs...)
326+
_reduced_indices(f, out, chosen) = out
327+
328+
reduced_axis(ax) = ax(oftype(ax.val, Base.OneTo(1)))
329+
reduced_axis0(ax) = ax(oftype(ax.val, length(ax.val) == 0 ? Base.OneTo(0) : Base.OneTo(1)))
330+
331+
283332
function Base.permutedims(A::AxisArray, perm)
284333
p = permutation(perm, axisnames(A))
285334
AxisArray(permutedims(A.data, p), axes(A)[[p...]])

test/core.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,3 +225,45 @@ map!(*, A2, A, A)
225225
@test isa(A2, AxisArray)
226226
@test A2.axes == A.axes
227227
@test A2.data == A.data .* A.data
228+
229+
# Reductions (issue #55)
230+
A = AxisArray(collect(reshape(1:15,3,5)), :y, :x)
231+
B = AxisArray(collect(reshape(1:15,3,5)), Axis{:y}(0.1:0.1:0.3), Axis{:x}(10:10:50))
232+
for C in (A, B)
233+
for op in (sum, minimum) # together, cover both reduced_indices and reduced_indices0
234+
axv = axisvalues(C)
235+
C1 = @inferred(op(C, 1))
236+
@test typeof(C1) == typeof(C)
237+
@test axisnames(C1) == (:y,:x)
238+
@test axisvalues(C1) === (oftype(axv[1], Base.OneTo(1)), axv[2])
239+
C2 = op(C, 2)
240+
@test typeof(C2) == typeof(C)
241+
@test axisnames(C2) == (:y,:x)
242+
@test axisvalues(C2) === (axv[1], oftype(axv[2], Base.OneTo(1)))
243+
C12 = @inferred(op(C, (1,2)))
244+
@test typeof(C12) == typeof(C)
245+
@test axisnames(C12) == (:y,:x)
246+
@test axisvalues(C12) === (oftype(axv[1], Base.OneTo(1)), oftype(axv[2], Base.OneTo(1)))
247+
if op == sum
248+
@test C1 == [6 15 24 33 42]
249+
@test C2 == reshape([35,40,45], 3, 1)
250+
@test C12 == reshape([120], 1, 1)
251+
else
252+
@test C1 == [1 4 7 10 13]
253+
@test C2 == reshape([1,2,3], 3, 1)
254+
@test C12 == reshape([1], 1, 1)
255+
end
256+
C1t = @inferred(op(C, Axis{:y}))
257+
@test C1t == C1
258+
C2t = @inferred(op(C, Axis{:x}))
259+
@test C2t == C2
260+
C12t = @inferred(op(C, (Axis{:y},Axis{:x})))
261+
@test C12t == C12
262+
C1t = @inferred(op(C, Axis{:y}()))
263+
@test C1t == C1
264+
C2t = @inferred(op(C, Axis{:x}()))
265+
@test C2t == C2
266+
C12t = @inferred(op(C, (Axis{:y}(),Axis{:x}())))
267+
@test C12t == C12
268+
end
269+
end

0 commit comments

Comments
 (0)