Skip to content

Commit fb4a44b

Browse files
authored
views should preserve static axes information (#295)
* views should preserve static axes information
1 parent a61ed56 commit fb4a44b

File tree

3 files changed

+11
-6
lines changed

3 files changed

+11
-6
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ArrayInterface"
22
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
3-
version = "6.0.10"
3+
version = "6.0.11"
44

55
[deps]
66
ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2"

src/axes.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ axes(A::ReshapedArray) = Base.axes(A)
103103
axes(A::PermutedDimsArray) = permute(axes(parent(A)), to_parent_dims(A))
104104
axes(A::MatAdjTrans) = permute(axes(parent(A)), to_parent_dims(A))
105105
axes(A::VecAdjTrans) = (SOneTo{1}(), axes(parent(A), 1))
106-
axes(A::SubArray) = map(Base.axes1, permute(A.indices, to_parent_dims(A)))
106+
axes(A::SubArray{<:Any,N}) where {N} = map(Base.Fix1(axes, A), Static.nstatic(Val(N)))
107107

108108
@inline axes(A, dim) = _axes(A, to_dims(A, dim))
109109
@inline function _axes(A, dim::Int)
@@ -120,19 +120,24 @@ end
120120
return getfield(axes(A), Int(dim))
121121
end
122122
end
123-
123+
@inline function _inbound_axes(A::SubArray, dim)
124+
pd = to_parent_dims(A, to_dims(A, dim))
125+
ax = getindex(A.indices, pd)
126+
ax isa Base.Slice || return axes(ax, 1)
127+
axes(parent(A))[pd]
128+
end
124129
@inline function axes(A::SubArray, dim::CanonicalInt)
125130
if dim > ndims(A)
126131
return OneTo(1)
127132
else
128-
return axes(getindex(A.indices, to_parent_dims(A, to_dims(A, dim))), 1)
133+
return _inbound_axes(A, dim)
129134
end
130135
end
131136
@inline function axes(A::SubArray, ::StaticInt{dim}) where {dim}
132137
if dim > ndims(A)
133138
return SOneTo{1}()
134139
else
135-
return axes(getindex(A.indices, to_parent_dims(A, to_dims(A, dim))), 1)
140+
return _inbound_axes(A, StaticInt(dim))
136141
end
137142
end
138143

test/axes.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ end
7171
@testset "SubArray Adjoint Axis" begin
7272
N = 4; d = rand(N);
7373

74-
@test @inferred(ArrayInterface.axes_types(typeof(view(d',:,1:2)))) === Tuple{ArrayInterface.OptionallyStaticUnitRange{StaticInt{1}, StaticInt{1}}, Base.OneTo{Int64}}
74+
@test @inferred(ArrayInterface.axes_types(typeof(view(d',:,1:2)))) === Tuple{ArrayInterface.OptionallyStaticUnitRange{StaticInt{1}, StaticInt{1}}, Base.OneTo{Int64}} === typeof(@inferred(ArrayInterface.axes(view(d',:,1:2)))) === typeof((ArrayInterface.axes(view(d',:,1:2),1),ArrayInterface.axes(view(d',:,1:2),2)))
7575

7676
end
7777
if isdefined(Base, :ReshapedReinterpretArray)

0 commit comments

Comments
 (0)