Skip to content

Commit 4ebd62a

Browse files
authored
Merge pull request #257 from JuliaArrays/subarrayofadjointvector
Infer static size of subarrays of adjoint vectors
2 parents afd0eac + 887f157 commit 4ebd62a

File tree

3 files changed

+17
-4
lines changed

3 files changed

+17
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ArrayInterface"
22
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
3-
version = "5.0.5"
3+
version = "5.0.6"
44

55
[deps]
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

src/axes.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,16 @@ function axes_types(::Type{T}) where {T<:AbstractRange}
4848
end
4949
end
5050
axes_types(::Type{T}) where {T<:ReshapedArray} = NTuple{ndims(T),OneTo{Int}}
51-
function _sub_axis_type(::Type{I}, dim::StaticInt{D}) where {I<:Tuple,D}
52-
axes_types(field_type(I, dim), static(1))
51+
function _sub_axis_type(::Type{PA}, ::Type{I}, dim::StaticInt{D}) where {I<:Tuple,PA,D}
52+
IT = field_type(I, dim)
53+
if IT <: Base.Slice
54+
axes_types(field_type(PA, dim), static(1))
55+
else
56+
axes_types(IT, static(1))
57+
end
5358
end
5459
@inline function axes_types(::Type{T}) where {N,P,I,T<:SubArray{<:Any,N,P,I}}
55-
return eachop_tuple(_sub_axis_type, to_parent_dims(T), I)
60+
return eachop_tuple(_sub_axis_type, to_parent_dims(T), axes_types(P), I)
5661
end
5762

5863
function axes_types(::Type{T}) where {T<:ReinterpretArray}

test/axes.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,14 @@ end
6767
@test @inferred(ArrayInterface.axes(b, static(3))) == 1:1
6868
end
6969

70+
@testset "SubArray Adjoint Axis" begin
71+
N = 4; d = rand(N);
72+
73+
@test @inferred(ArrayInterface.axes_types(typeof(view(d',:,1:2)))) === Tuple{ArrayInterface.OptionallyStaticUnitRange{StaticInt{1}, StaticInt{1}}, Base.OneTo{Int64}}
74+
75+
end
7076
if isdefined(Base, :ReshapedReinterpretArray)
77+
@testset "ReshapedReinterpretArray" begin
7178
a = rand(3, 5)
7279
ua = reinterpret(reshape, UInt64, a)
7380
@test ArrayInterface.axes(ua) === ArrayInterface.axes(a)
@@ -79,4 +86,5 @@ if isdefined(Base, :ReshapedReinterpretArray)
7986
@test @inferred(ArrayInterface.axes(u8a, static(2))) isa ArrayInterface.axes_types(u8a, 2)
8087
fa = reinterpret(reshape, Float64, copy(u8a))
8188
@inferred(ArrayInterface.axes(fa)) isa ArrayInterface.axes_types(fa)
89+
end
8290
end

0 commit comments

Comments
 (0)