Skip to content

Commit 8f8272d

Browse files
authored
Merge pull request #56 from JuliaArrays/teh/reductions
Reductions preserve the AxisArray wrapper (fixes #55)
2 parents c56dc40 + ff95732 commit 8f8272d

File tree

2 files changed

+105
-0
lines changed

2 files changed

+105
-0
lines changed

src/core.jl

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,69 @@ Base.similar{S}(A::AxisArray, ::Type{S}, ax1::Axis, axs::Axis...) = similar(A, S
280280
end
281281
end
282282

283+
# These methods allow us to preserve the AxisArray under reductions
284+
# Note that we only extend the following two methods, and then have it
285+
# dispatch to package-local `reduced_indices` and `reduced_indices0`
286+
# methods. This avoids a whole slew of ambiguities.
287+
if VERSION == v"0.5.0"
288+
Base.reduced_dims(A::AxisArray, region) = reduced_indices(axes(A), region)
289+
Base.reduced_dims0(A::AxisArray, region) = reduced_indices0(axes(A), region)
290+
else
291+
Base.reduced_indices(A::AxisArray, region) = reduced_indices(axes(A), region)
292+
Base.reduced_indices0(A::AxisArray, region) = reduced_indices0(axes(A), region)
293+
end
294+
295+
reduced_indices{N}(axs::Tuple{Vararg{Axis,N}}, ::Tuple{}) = axs
296+
reduced_indices0{N}(axs::Tuple{Vararg{Axis,N}}, ::Tuple{}) = axs
297+
reduced_indices{N}(axs::Tuple{Vararg{Axis,N}}, region::Integer) =
298+
reduced_indices(axs, (region,))
299+
reduced_indices0{N}(axs::Tuple{Vararg{Axis,N}}, region::Integer) =
300+
reduced_indices0(axs, (region,))
301+
302+
reduced_indices{N}(axs::Tuple{Vararg{Axis,N}}, region::Dims) =
303+
map((ax,d)->dregion ? reduced_axis(ax) : ax, axs, ntuple(identity, Val{N}))
304+
reduced_indices0{N}(axs::Tuple{Vararg{Axis,N}}, region::Dims) =
305+
map((ax,d)->dregion ? reduced_axis0(ax) : ax, axs, ntuple(identity, Val{N}))
306+
307+
@inline reduced_indices{Ax<:Axis}(axs::Tuple{Vararg{Axis}}, region::Type{Ax}) =
308+
_reduced_indices(reduced_axis, (), region, axs...)
309+
@inline reduced_indices0{Ax<:Axis}(axs::Tuple{Vararg{Axis}}, region::Type{Ax}) =
310+
_reduced_indices(reduced_axis0, (), region, axs...)
311+
@inline reduced_indices(axs::Tuple{Vararg{Axis}}, region::Axis) =
312+
_reduced_indices(reduced_axis, (), region, axs...)
313+
@inline reduced_indices0(axs::Tuple{Vararg{Axis}}, region::Axis) =
314+
_reduced_indices(reduced_axis0, (), region, axs...)
315+
316+
reduced_indices(axs::Tuple{Vararg{Axis}}, region::Tuple) =
317+
reduced_indices(reduced_indices(axs, region[1]), tail(region))
318+
reduced_indices(axs::Tuple{Vararg{Axis}}, region::Tuple{Vararg{Axis}}) =
319+
reduced_indices(reduced_indices(axs, region[1]), tail(region))
320+
reduced_indices0(axs::Tuple{Vararg{Axis}}, region::Tuple) =
321+
reduced_indices0(reduced_indices0(axs, region[1]), tail(region))
322+
reduced_indices0(axs::Tuple{Vararg{Axis}}, region::Tuple{Vararg{Axis}}) =
323+
reduced_indices0(reduced_indices0(axs, region[1]), tail(region))
324+
325+
@pure samesym{n1,n2}(::Type{Axis{n1}}, ::Type{Axis{n2}}) = Val{n1==n2}()
326+
samesym{n1,n2,T1,T2}(::Type{Axis{n1,T1}}, ::Type{Axis{n2,T2}}) = samesym(Axis{n1},Axis{n2})
327+
samesym{n1,n2}(::Type{Axis{n1}}, ::Axis{n2}) = samesym(Axis{n1}, Axis{n2})
328+
samesym{n1,n2}(::Axis{n1}, ::Type{Axis{n2}}) = samesym(Axis{n1}, Axis{n2})
329+
samesym{n1,n2}(::Axis{n1}, ::Axis{n2}) = samesym(Axis{n1}, Axis{n2})
330+
331+
@inline _reduced_indices{Ax<:Axis}(f, out, chosen::Type{Ax}, ax::Axis, axs...) =
332+
__reduced_indices(f, out, samesym(chosen, ax), chosen, ax, axs)
333+
@inline _reduced_indices(f, out, chosen::Axis, ax::Axis, axs...) =
334+
__reduced_indices(f, out, samesym(chosen, ax), chosen, ax, axs)
335+
_reduced_indices(f, out, chosen) = out
336+
337+
@inline __reduced_indices(f, out, ::Val{true}, chosen, ax, axs) =
338+
_reduced_indices(f, (out..., f(ax)), chosen, axs...)
339+
@inline __reduced_indices(f, out, ::Val{false}, chosen, ax, axs) =
340+
_reduced_indices(f, (out..., ax), chosen, axs...)
341+
342+
reduced_axis(ax) = ax(oftype(ax.val, Base.OneTo(1)))
343+
reduced_axis0(ax) = ax(oftype(ax.val, length(ax.val) == 0 ? Base.OneTo(0) : Base.OneTo(1)))
344+
345+
283346
function Base.permutedims(A::AxisArray, perm)
284347
p = permutation(perm, axisnames(A))
285348
AxisArray(permutedims(A.data, p), axes(A)[[p...]])

test/core.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,3 +225,45 @@ map!(*, A2, A, A)
225225
@test isa(A2, AxisArray)
226226
@test A2.axes == A.axes
227227
@test A2.data == A.data .* A.data
228+
229+
# Reductions (issue #55)
230+
A = AxisArray(collect(reshape(1:15,3,5)), :y, :x)
231+
B = AxisArray(collect(reshape(1:15,3,5)), Axis{:y}(0.1:0.1:0.3), Axis{:x}(10:10:50))
232+
for C in (A, B)
233+
for op in (sum, minimum) # together, cover both reduced_indices and reduced_indices0
234+
axv = axisvalues(C)
235+
C1 = @inferred(op(C, 1))
236+
@test typeof(C1) == typeof(C)
237+
@test axisnames(C1) == (:y,:x)
238+
@test axisvalues(C1) === (oftype(axv[1], Base.OneTo(1)), axv[2])
239+
C2 = op(C, 2)
240+
@test typeof(C2) == typeof(C)
241+
@test axisnames(C2) == (:y,:x)
242+
@test axisvalues(C2) === (axv[1], oftype(axv[2], Base.OneTo(1)))
243+
C12 = @inferred(op(C, (1,2)))
244+
@test typeof(C12) == typeof(C)
245+
@test axisnames(C12) == (:y,:x)
246+
@test axisvalues(C12) === (oftype(axv[1], Base.OneTo(1)), oftype(axv[2], Base.OneTo(1)))
247+
if op == sum
248+
@test C1 == [6 15 24 33 42]
249+
@test C2 == reshape([35,40,45], 3, 1)
250+
@test C12 == reshape([120], 1, 1)
251+
else
252+
@test C1 == [1 4 7 10 13]
253+
@test C2 == reshape([1,2,3], 3, 1)
254+
@test C12 == reshape([1], 1, 1)
255+
end
256+
C1t = @inferred(op(C, Axis{:y}))
257+
@test C1t == C1
258+
C2t = @inferred(op(C, Axis{:x}))
259+
@test C2t == C2
260+
C12t = @inferred(op(C, (Axis{:y},Axis{:x})))
261+
@test C12t == C12
262+
C1t = @inferred(op(C, Axis{:y}()))
263+
@test C1t == C1
264+
C2t = @inferred(op(C, Axis{:x}()))
265+
@test C2t == C2
266+
C12t = @inferred(op(C, (Axis{:y}(),Axis{:x}())))
267+
@test C12t == C12
268+
end
269+
end

0 commit comments

Comments
 (0)