Skip to content

Commit 1d8fe52

Browse files
committed
Add tests and docs for to/from_parent_dims
1 parent 22490b7 commit 1d8fe52

File tree

3 files changed

+34
-1
lines changed

3 files changed

+34
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ArrayInterface"
22
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
3-
version = "3"
3+
version = "3.0.0"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/dimensions.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@ function is_increasing(perm::Tuple{StaticInt{X},StaticInt{Y}}) where {X, Y}
1818
end
1919
is_increasing(::Tuple{StaticInt{X}}) where {X} = True()
2020

21+
"""
22+
from_parent_dims(::Type{T}) -> Bool
23+
24+
Returns the mapping from parent dimensions to child dimensions.
25+
"""
2126
from_parent_dims(::Type{T}) where {T} = nstatic(Val(ndims(T)))
2227
from_parent_dims(::Type{T}) where {T<:Union{Transpose,Adjoint}} = (StaticInt(2), One())
2328
from_parent_dims(::Type{<:SubArray{T,N,A,I}}) where {T,N,A,I} = _from_sub_dims(A, I)
@@ -38,6 +43,11 @@ function from_parent_dims(::Type{<:PermutedDimsArray{T,N,<:Any,I}}) where {T,N,I
3843
return _val_to_static(Val(I))
3944
end
4045

46+
"""
47+
to_parent_dims(::Type{T}) -> Bool
48+
49+
Returns the mapping from child dimensions to parent dimensions.
50+
"""
4151
to_parent_dims(x) = to_parent_dims(typeof(x))
4252
to_parent_dims(::Type{T}) where {T} = nstatic(Val(ndims(T)))
4353
to_parent_dims(::Type{T}) where {T<:Union{Transpose,Adjoint}} = (StaticInt(2), One())

test/dimensions.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,28 @@
11
@testset "dimensions" begin
22

3+
@testset "dimension permutations" begin
4+
a = ones(2, 2, 2)
5+
perm = PermutedDimsArray(a, (3, 1, 2))
6+
mview = view(perm, :, 1, :)
7+
madj = mview'
8+
vview = view(madj, 1, :)
9+
vadj = vview'
10+
11+
@test @inferred(ArrayInterface.to_parent_dims(typeof(a))) == (1, 2, 3)
12+
@test @inferred(ArrayInterface.to_parent_dims(typeof(perm))) == (3, 1, 2)
13+
@test @inferred(ArrayInterface.to_parent_dims(typeof(mview))) == (1, 3)
14+
@test @inferred(ArrayInterface.to_parent_dims(typeof(madj))) == (2, 1)
15+
@test @inferred(ArrayInterface.to_parent_dims(typeof(vview))) == (2,)
16+
@test @inferred(ArrayInterface.to_parent_dims(typeof(vadj))) == (2, 1)
17+
18+
@test @inferred(ArrayInterface.from_parent_dims(typeof(a))) == (1, 2, 3)
19+
@test @inferred(ArrayInterface.from_parent_dims(typeof(perm))) == (2, 3, 1)
20+
@test @inferred(ArrayInterface.from_parent_dims(typeof(mview))) == (1, 0, 2)
21+
@test @inferred(ArrayInterface.from_parent_dims(typeof(madj))) == (2, 1)
22+
@test @inferred(ArrayInterface.from_parent_dims(typeof(vview))) == (0, 1)
23+
@test @inferred(ArrayInterface.from_parent_dims(typeof(vadj))) == (2, 1)
24+
end
25+
326
@testset "to_dims" begin
427
@testset "small case" begin
528
@test ArrayInterface.to_dims((:x, :y), :x) == 1

0 commit comments

Comments
 (0)