Skip to content

Commit e0af2eb

Browse files
committed
Infer static size of subarrays of adjoint vectors
1 parent b2889ed commit e0af2eb

File tree

2 files changed

+16
-3
lines changed

2 files changed

+16
-3
lines changed

src/axes.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,16 @@ function axes_types(::Type{T}) where {T<:AbstractRange}
4848
end
4949
end
5050
axes_types(::Type{T}) where {T<:ReshapedArray} = NTuple{ndims(T),OneTo{Int}}
51-
function _sub_axis_type(::Type{I}, dim::StaticInt{D}) where {I<:Tuple,D}
52-
axes_types(field_type(I, dim), static(1))
51+
function _sub_axis_type(::Type{I}, ::Type{PA}, dim::StaticInt{D}) where {I<:Tuple,PA,D}
52+
IT = field_type(I, dim)
53+
if IT <: Base.Slice
54+
axes_types(field_type(PA, dim), static(1))
55+
else
56+
axes_types(IT, static(1))
57+
end
5358
end
5459
@inline function axes_types(::Type{T}) where {N,P,I,T<:SubArray{<:Any,N,P,I}}
55-
return eachop_tuple(_sub_axis_type, to_parent_dims(T), I)
60+
return eachop_tuple(_sub_axis_type, to_parent_dims(T), axes_types(P), I)
5661
end
5762

5863
function axes_types(::Type{T}) where {T<:ReinterpretArray}

test/axes.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,14 @@ end
6767
@test @inferred(ArrayInterface.axes(b, static(3))) == 1:1
6868
end
6969

70+
@testset "SubArray Adjoint Axis" begin
71+
N = 4; d = rand(N);
72+
73+
@test @inferred(ArrayInterface.axes_types(typeof(view(d',:,1:2)))) === Tuple{ArrayInterface.OptionallyStaticUnitRange{Static.StaticInt{1}, Static.StaticInt{1}}, Base.OneTo{Int64}}
74+
75+
end
7076
if isdefined(Base, :ReshapedReinterpretArray)
77+
@testset "ReshapedReinterpretArray" begin
7178
a = rand(3, 5)
7279
ua = reinterpret(reshape, UInt64, a)
7380
@test ArrayInterface.axes(ua) === ArrayInterface.axes(a)
@@ -79,4 +86,5 @@ if isdefined(Base, :ReshapedReinterpretArray)
7986
@test @inferred(ArrayInterface.axes(u8a, static(2))) isa ArrayInterface.axes_types(u8a, 2)
8087
fa = reinterpret(reshape, Float64, copy(u8a))
8188
@inferred(ArrayInterface.axes(fa)) isa ArrayInterface.axes_types(fa)
89+
end
8290
end

0 commit comments

Comments
 (0)