@@ -280,6 +280,69 @@ Base.similar{S}(A::AxisArray, ::Type{S}, ax1::Axis, axs::Axis...) = similar(A, S
280
280
end
281
281
end
282
282
283
+ # These methods allow us to preserve the AxisArray under reductions
284
+ # Note that we only extend the following two methods, and then have it
285
+ # dispatch to package-local `reduced_indices` and `reduced_indices0`
286
+ # methods. This avoids a whole slew of ambiguities.
287
+ if VERSION == v " 0.5.0"
288
+ Base. reduced_dims (A:: AxisArray , region) = reduced_indices (axes (A), region)
289
+ Base. reduced_dims0 (A:: AxisArray , region) = reduced_indices0 (axes (A), region)
290
+ else
291
+ Base. reduced_indices (A:: AxisArray , region) = reduced_indices (axes (A), region)
292
+ Base. reduced_indices0 (A:: AxisArray , region) = reduced_indices0 (axes (A), region)
293
+ end
294
+
295
+ reduced_indices {N} (axs:: Tuple{Vararg{Axis,N}} , :: Tuple{} ) = axs
296
+ reduced_indices0 {N} (axs:: Tuple{Vararg{Axis,N}} , :: Tuple{} ) = axs
297
+ reduced_indices {N} (axs:: Tuple{Vararg{Axis,N}} , region:: Integer ) =
298
+ reduced_indices (axs, (region,))
299
+ reduced_indices0 {N} (axs:: Tuple{Vararg{Axis,N}} , region:: Integer ) =
300
+ reduced_indices0 (axs, (region,))
301
+
302
+ reduced_indices {N} (axs:: Tuple{Vararg{Axis,N}} , region:: Dims ) =
303
+ map ((ax,d)-> d∈ region ? reduced_axis (ax) : ax, axs, ntuple (identity, Val{N}))
304
+ reduced_indices0 {N} (axs:: Tuple{Vararg{Axis,N}} , region:: Dims ) =
305
+ map ((ax,d)-> d∈ region ? reduced_axis0 (ax) : ax, axs, ntuple (identity, Val{N}))
306
+
307
+ @inline reduced_indices {Ax<:Axis} (axs:: Tuple{Vararg{Axis}} , region:: Type{Ax} ) =
308
+ _reduced_indices (reduced_axis, (), region, axs... )
309
+ @inline reduced_indices0 {Ax<:Axis} (axs:: Tuple{Vararg{Axis}} , region:: Type{Ax} ) =
310
+ _reduced_indices (reduced_axis0, (), region, axs... )
311
+ @inline reduced_indices (axs:: Tuple{Vararg{Axis}} , region:: Axis ) =
312
+ _reduced_indices (reduced_axis, (), region, axs... )
313
+ @inline reduced_indices0 (axs:: Tuple{Vararg{Axis}} , region:: Axis ) =
314
+ _reduced_indices (reduced_axis0, (), region, axs... )
315
+
316
+ reduced_indices (axs:: Tuple{Vararg{Axis}} , region:: Tuple ) =
317
+ reduced_indices (reduced_indices (axs, region[1 ]), tail (region))
318
+ reduced_indices (axs:: Tuple{Vararg{Axis}} , region:: Tuple{Vararg{Axis}} ) =
319
+ reduced_indices (reduced_indices (axs, region[1 ]), tail (region))
320
+ reduced_indices0 (axs:: Tuple{Vararg{Axis}} , region:: Tuple ) =
321
+ reduced_indices0 (reduced_indices0 (axs, region[1 ]), tail (region))
322
+ reduced_indices0 (axs:: Tuple{Vararg{Axis}} , region:: Tuple{Vararg{Axis}} ) =
323
+ reduced_indices0 (reduced_indices0 (axs, region[1 ]), tail (region))
324
+
325
+ @pure samesym {n1,n2} (:: Type{Axis{n1}} , :: Type{Axis{n2}} ) = Val {n1==n2} ()
326
+ samesym {n1,n2,T1,T2} (:: Type{Axis{n1,T1}} , :: Type{Axis{n2,T2}} ) = samesym (Axis{n1},Axis{n2})
327
+ samesym {n1,n2} (:: Type{Axis{n1}} , :: Axis{n2} ) = samesym (Axis{n1}, Axis{n2})
328
+ samesym {n1,n2} (:: Axis{n1} , :: Type{Axis{n2}} ) = samesym (Axis{n1}, Axis{n2})
329
+ samesym {n1,n2} (:: Axis{n1} , :: Axis{n2} ) = samesym (Axis{n1}, Axis{n2})
330
+
331
+ @inline _reduced_indices {Ax<:Axis} (f, out, chosen:: Type{Ax} , ax:: Axis , axs... ) =
332
+ __reduced_indices (f, out, samesym (chosen, ax), chosen, ax, axs)
333
+ @inline _reduced_indices (f, out, chosen:: Axis , ax:: Axis , axs... ) =
334
+ __reduced_indices (f, out, samesym (chosen, ax), chosen, ax, axs)
335
+ _reduced_indices (f, out, chosen) = out
336
+
337
+ @inline __reduced_indices (f, out, :: Val{true} , chosen, ax, axs) =
338
+ _reduced_indices (f, (out... , f (ax)), chosen, axs... )
339
+ @inline __reduced_indices (f, out, :: Val{false} , chosen, ax, axs) =
340
+ _reduced_indices (f, (out... , ax), chosen, axs... )
341
+
342
+ reduced_axis (ax) = ax (oftype (ax. val, Base. OneTo (1 )))
343
+ reduced_axis0 (ax) = ax (oftype (ax. val, length (ax. val) == 0 ? Base. OneTo (0 ) : Base. OneTo (1 )))
344
+
345
+
283
346
function Base. permutedims (A:: AxisArray , perm)
284
347
p = permutation (perm, axisnames (A))
285
348
AxisArray (permutedims (A. data, p), axes (A)[[p... ]])
0 commit comments