Skip to content

Commit 7e15abc

Browse files
authored
Merge pull request #93 from Tokazama/master
Easier use of ArrayInterface.size/axes/to_dims
2 parents 70bd07f + 6ad6877 commit 7e15abc

File tree

3 files changed

+32
-11
lines changed

3 files changed

+32
-11
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ArrayInterface"
22
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
3-
version = "2.14.2"
3+
version = "2.14.3"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/dimensions.jl

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -71,13 +71,14 @@ end
7171
7272
This returns the dimension(s) of `x` corresponding to `d`.
7373
"""
74-
to_dims(x, d::Integer) = Int(d)
75-
to_dims(x, d::Colon) = d # `:` is the default for most methods that take `dims`
76-
@inline to_dims(x, d::Tuple) = map(i -> to_dims(x, i), d)
77-
@inline function to_dims(x, d::Symbol)::Int
78-
i = _sym_to_dim(dimnames(x), d)
74+
to_dims(x, d) = to_dims(dimnames(x), d)
75+
to_dims(x::Tuple{Vararg{Symbol}}, d::Integer) = Int(d)
76+
to_dims(x::Tuple{Vararg{Symbol}}, d::Colon) = d # `:` is the default for most methods that take `dims`
77+
@inline to_dims(x::Tuple{Vararg{Symbol}}, d::Tuple) = map(i -> to_dims(x, i), d)
78+
@inline function to_dims(x::Tuple{Vararg{Symbol}}, d::Symbol)::Int
79+
i = _sym_to_dim(x, d)
7980
if i === 0
80-
throw(ArgumentError("Specified name ($(repr(d))) does not match any dimension name ($(dimnames(x)))"))
81+
throw(ArgumentError("Specified name ($(repr(d))) does not match any dimension name ($(x))"))
8182
end
8283
return i
8384
end
@@ -152,6 +153,7 @@ julia> ArrayInterface.size(A)
152153
```
153154
"""
154155
@inline size(A) = Base.size(A)
156+
@inline size(A, d::Integer) = size(A)[Int(d)]
155157
@inline size(A, d) = Base.size(A, to_dims(A, d))
156158
@inline size(x::LinearAlgebra.Adjoint{T,V}) where {T, V <: AbstractVector{T}} = (One(), static_length(x))
157159
@inline size(x::LinearAlgebra.Transpose{T,V}) where {T, V <: AbstractVector{T}} = (One(), static_length(x))
@@ -161,7 +163,8 @@ julia> ArrayInterface.size(A)
161163
162164
Return a valid range that maps to each index along dimension `d` of `A`.
163165
"""
164-
@inline axes(A, d) = Base.axes(A, to_dims(A, d))
166+
@inline axes(A, d) = axes(A, to_dims(A, d))
167+
@inline axes(A, d::Integer) = axes(A)[Int(d)]
165168

166169
"""
167170
axes(A)

test/dimensions.jl

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,21 @@
11

22
@testset "dimensions" begin
33

4+
@testset "to_dims" begin
5+
@testset "small case" begin
6+
@test ArrayInterface.to_dims((:x, :y), :x) == 1
7+
@test ArrayInterface.to_dims((:x, :y), :y) == 2
8+
@test_throws ArgumentError ArrayInterface.to_dims((:x, :y), :z) # not found
9+
end
10+
11+
@testset "large case" begin
12+
@test ArrayInterface.to_dims((:x, :y, :a, :b, :c, :d), :x) == 1
13+
@test ArrayInterface.to_dims((:x, :y, :a, :b, :c, :d), :a) == 3
14+
@test ArrayInterface.to_dims((:x, :y, :a, :b, :c, :d), :d) == 6
15+
@test_throws ArgumentError ArrayInterface.to_dims((:x, :y, :a, :b, :c, :d), :z) # not found
16+
end
17+
end
18+
419
struct NamedDimsWrapper{L,T,N,P<:AbstractArray{T,N}} <: AbstractArray{T,N}
520
parent::P
621
NamedDimsWrapper{L}(p) where {L} = new{L,eltype(p),ndims(p),typeof(p)}(p)
@@ -10,8 +25,11 @@ ArrayInterface.has_dimnames(::Type{T}) where {T<:NamedDimsWrapper} = true
1025
ArrayInterface.dimnames(::Type{T}) where {L,T<:NamedDimsWrapper{L}} = L
1126
Base.parent(x::NamedDimsWrapper) = x.parent
1227
Base.size(x::NamedDimsWrapper) = size(parent(x))
28+
Base.size(x::NamedDimsWrapper, d) = ArrayInterface.size(x, d)
1329
Base.axes(x::NamedDimsWrapper) = axes(parent(x))
30+
Base.axes(x::NamedDimsWrapper, d) = ArrayInterface.axes(x, d)
1431
Base.strides(x::NamedDimsWrapper) = Base.strides(parent(x))
32+
Base.strides(x::NamedDimsWrapper, d) = ArrayInterface.strides(x, d)
1533

1634
Base.getindex(x::NamedDimsWrapper; kwargs...) = ArrayInterface.getindex(x; kwargs...)
1735
Base.getindex(x::NamedDimsWrapper, args...) = ArrayInterface.getindex(x, args...)
@@ -46,9 +64,9 @@ dnums = ntuple(+, length(d))
4664
@test @inferred(ArrayInterface.to_dims(x, reverse(d))) === reverse(dnums)
4765
@test_throws ArgumentError ArrayInterface.to_dims(x, :z)
4866

49-
@test @inferred(ArrayInterface.size(x, :x)) == size(parent(x), 1)
50-
@test @inferred(ArrayInterface.axes(x, :x)) == axes(parent(x), 1)
51-
@test ArrayInterface.strides(x, :x) == ArrayInterface.strides(parent(x))[1]
67+
@test @inferred(size(x, :x)) == size(parent(x), 1)
68+
@test @inferred(axes(x, :x)) == axes(parent(x), 1)
69+
@test strides(x, :x) == ArrayInterface.strides(parent(x))[1]
5270

5371
x[x = 1] = [2, 3]
5472
@test @inferred(getindex(x, x = 1)) == [2, 3]

0 commit comments

Comments
 (0)