Skip to content

Commit 6be0ff6

Browse files
authored
Merge pull request #112 from invenia/reduced-oneto
Support reductions over dimensions with non-Number axes
2 parents 6fdaefb + e3ff94a commit 6be0ff6

File tree

2 files changed

+50
-14
lines changed

2 files changed

+50
-14
lines changed

src/core.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -342,8 +342,11 @@ _reduced_indices(f, out, chosen) = out
342342
@inline __reduced_indices(f, out, ::Val{false}, chosen, ax, axs) =
343343
_reduced_indices(f, (out..., ax), chosen, axs...)
344344

345-
reduced_axis(ax) = ax(oftype(ax.val, Base.OneTo(1)))
346-
reduced_axis0(ax) = ax(oftype(ax.val, length(ax.val) == 0 ? Base.OneTo(0) : Base.OneTo(1)))
345+
reduced_axis( ax::Axis{name,<:AbstractArray{T}}) where {name,T<:Number} = ax(oftype(ax.val, Base.OneTo(1)))
346+
reduced_axis0(ax::Axis{name,<:AbstractArray{T}}) where {name,T<:Number} = ax(oftype(ax.val, length(ax.val) == 0 ? Base.OneTo(0) : Base.OneTo(1)))
347+
348+
reduced_axis( ax) = ax(Base.OneTo(1))
349+
reduced_axis0(ax) = ax(length(ax.val) == 0 ? Base.OneTo(0) : Base.OneTo(1))
347350

348351

349352
function Base.permutedims(A::AxisArray, perm)

test/core.jl

Lines changed: 45 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -278,17 +278,50 @@ for C in arrays
278278
@test C2 == reshape([1,2,3], 3, 1)
279279
@test C12 == reshape([1], 1, 1)
280280
end
281-
C1t = @inferred(op(C, Axis{:y}))
282-
@test C1t == C1
283-
C2t = @inferred(op(C, Axis{:x}))
284-
@test C2t == C2
285-
C12t = @inferred(op(C, (Axis{:y},Axis{:x})))
286-
@test C12t == C12
287-
C1t = @inferred(op(C, Axis{:y}()))
288-
@test C1t == C1
289-
C2t = @inferred(op(C, Axis{:x}()))
290-
@test C2t == C2
291-
C12t = @inferred(op(C, (Axis{:y}(),Axis{:x}())))
292-
@test C12t == C12
281+
@test @inferred(op(C, Axis{:y})) == C1
282+
@test @inferred(op(C, Axis{:x})) == C2
283+
@test @inferred(op(C, (Axis{:y},Axis{:x}))) == C12
284+
@test @inferred(op(C, Axis{:y}())) == C1
285+
@test @inferred(op(C, Axis{:x}())) == C2
286+
@test @inferred(op(C, (Axis{:y}(),Axis{:x}()))) == C12
293287
end
294288
end
289+
290+
function typeof_noaxis(::AxisArray{T,N,D}) where {T,N,D}
291+
AxisArray{T,N,D}
292+
end
293+
294+
# uninferrable
295+
C = AxisArray(collect(reshape(1:15,3,5)), Axis{:y}([:a,:b,:c]), Axis{:x}(["a","b","c","d","e"]))
296+
for op in functions # together, cover both reduced_indices and reduced_indices0
297+
axv = axisvalues(C)
298+
C1 = op(C, 1)
299+
@test typeof_noaxis(C1) == typeof_noaxis(C)
300+
@test axisnames(C1) == (:y,:x)
301+
@test axisvalues(C1) === (Base.OneTo(1), axv[2])
302+
C2 = op(C, 2)
303+
@test typeof_noaxis(C2) == typeof_noaxis(C)
304+
@test axisnames(C2) == (:y,:x)
305+
@test axisvalues(C2) === (axv[1], Base.OneTo(1))
306+
C12 = op(C, (1,2))
307+
@test typeof_noaxis(C12) == typeof_noaxis(C)
308+
@test axisnames(C12) == (:y,:x)
309+
@test axisvalues(C12) === (Base.OneTo(1), Base.OneTo(1))
310+
if op == sum
311+
@test C1 == [6 15 24 33 42]
312+
@test C2 == reshape([35,40,45], 3, 1)
313+
@test C12 == reshape([120], 1, 1)
314+
else
315+
@test C1 == [1 4 7 10 13]
316+
@test C2 == reshape([1,2,3], 3, 1)
317+
@test C12 == reshape([1], 1, 1)
318+
end
319+
@test @inferred(op(C, Axis{:y})) == C1
320+
@test @inferred(op(C, Axis{:x})) == C2
321+
# Unfortunately the type of (Axis{:y},Axis{:x}) is Tuple{UnionAll,UnionAll} so methods will not specialize
322+
@test_broken @inferred(op(C, (Axis{:y},Axis{:x}))) == C12
323+
@test op(C, (Axis{:y},Axis{:x})) == C12
324+
@test @inferred(op(C, Axis{:y}())) == C1
325+
@test @inferred(op(C, Axis{:x}())) == C2
326+
@test @inferred(op(C, (Axis{:y}(),Axis{:x}()))) == C12
327+
end

0 commit comments

Comments
 (0)