Skip to content

Commit a48dcad

Browse files
authored
trailing dimensions in eachslice (#58791)
fixes #51692
1 parent a452acc commit a48dcad

File tree

2 files changed

+38
-7
lines changed

2 files changed

+38
-7
lines changed

base/slicearray.jl

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ struct Slices{P,SM,AX,S,N} <: AbstractSlices{S,N}
2525
"""
2626
parent::P
2727
"""
28-
A tuple of length `ndims(parent)`, denoting how each dimension should be handled:
28+
A tuple of length at least `ndims(parent)`, denoting how each dimension should be handled:
2929
- an integer `i`: this is the `i`th dimension of the outer `Slices` object.
3030
- `:`: an "inner" dimension
3131
"""
@@ -39,34 +39,39 @@ end
3939
unitaxis(::AbstractArray) = Base.OneTo(1)
4040

4141
function Slices(A::P, slicemap::SM, ax::AX) where {P,SM,AX}
42+
length(slicemap) >= ndims(A) ||
43+
throw(ArgumentError("Slices cannot be constructed with a slicemap of fewer elements than the parent has dimensions"))
4244
N = length(ax)
43-
argT = map((a,l) -> l === (:) ? Colon : eltype(a), axes(A), slicemap)
45+
parent_axes = ntuple(d -> axes(A, d), length(slicemap))
46+
argT = map((a,l) -> l === (:) ? Colon : eltype(a), parent_axes, slicemap)
4447
S = Base.promote_op(view, P, argT...)
4548
Slices{P,SM,AX,S,N}(A, slicemap, ax)
4649
end
4750

4851
_slice_check_dims(N) = nothing
4952
function _slice_check_dims(N, dim, dims...)
50-
1 <= dim <= N || throw(DimensionMismatch("Invalid dimension $dim"))
53+
1 <= dim || throw(DimensionMismatch("Invalid dimension $dim"))
5154
dim in dims && throw(DimensionMismatch("Dimensions $dims are not unique"))
5255
_slice_check_dims(N,dims...)
5356
end
5457

5558
@constprop :aggressive function _eachslice(A::AbstractArray{T,N}, dims::NTuple{M,Integer}, drop::Bool) where {T,N,M}
5659
_slice_check_dims(N,dims...)
60+
N_ = foldl(max, dims; init=N)
61+
5762
if drop
5863
# if N = 4, dims = (3,1) then
5964
# axes = (axes(A,3), axes(A,1))
6065
# slicemap = (2, :, 1, :)
6166
ax = map(dim -> axes(A,dim), dims)
62-
slicemap = ntuple(dim -> something(findfirst(isequal(dim), dims), (:)), N)
67+
slicemap = ntuple(dim -> something(findfirst(isequal(dim), dims), (:)), N_)
6368
return Slices(A, slicemap, ax)
6469
else
6570
# if N = 4, dims = (3,1) then
6671
# axes = (axes(A,1), OneTo(1), axes(A,3), OneTo(1))
6772
# slicemap = (1, :, 3, :)
68-
ax = ntuple(dim -> dim in dims ? axes(A,dim) : unitaxis(A), N)
69-
slicemap = ntuple(dim -> dim in dims ? dim : (:), N)
73+
ax = ntuple(dim -> dim in dims ? axes(A,dim) : unitaxis(A), N_)
74+
slicemap = ntuple(dim -> dim in dims ? dim : (:), N_)
7075
return Slices(A, slicemap, ax)
7176
end
7277
end

test/arrayops.jl

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2402,7 +2402,7 @@ end
24022402
M = [1 2 3; 4 5 6; 7 8 9]
24032403
@test eachrow(M) == eachslice(M, dims = 1) == [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
24042404
@test eachcol(M) == eachslice(M, dims = 2) == [[1, 4, 7], [2, 5, 8], [3, 6, 9]]
2405-
@test_throws DimensionMismatch eachslice(M, dims = 4)
2405+
@test eachslice(M, dims = 4) == [[1 2 3; 4 5 6; 7 8 9;;;]]
24062406

24072407
SR = @inferred eachrow(M)
24082408
@test SR[2] isa eltype(SR)
@@ -2467,6 +2467,32 @@ end
24672467
@test_throws BoundsError A[2,3] = [4,5]
24682468
@test_throws BoundsError A[2,3] .= [4,5]
24692469
end
2470+
2471+
@testset "trailing dimensions" begin
2472+
v = collect(1:3)
2473+
2474+
S2 = eachslice(v; dims = 2, drop=true)
2475+
@test S2 isa AbstractSlices{<:AbstractVector, 1}
2476+
@test size(S2) == (1,)
2477+
@test S2[1] == v
2478+
2479+
S2K = eachslice(v; dims = 2, drop=false)
2480+
@test S2K isa AbstractSlices{<:AbstractVector, 2}
2481+
@test size(S2K) == (1,1)
2482+
@test S2K[1,1] == v
2483+
2484+
M = reshape(1:6, 2, 3)
2485+
2486+
S13 = eachslice(M; dims = (1,3))
2487+
@test size(S13) == (2,1)
2488+
@test S13[2,1] == M[2,:,1]
2489+
2490+
S13K = eachslice(M; dims = (1,3), drop=false)
2491+
@test size(S13K) == (2,1,1)
2492+
@test S13K[1,1,1] == M[1,:]
2493+
@test S13K[2,1,1] == M[2,:]
2494+
end
2495+
24702496
end
24712497

24722498
###

0 commit comments

Comments
 (0)