Skip to content

Commit 19bfb47

Browse files
authored
Merge branch 'master' into ChrisRackauckas-patch-1
2 parents 6d760a2 + b922313 commit 19bfb47

File tree

4 files changed

+103
-74
lines changed

4 files changed

+103
-74
lines changed

lib/ArrayInterfaceCore/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ArrayInterfaceCore"
22
uuid = "30b0a656-2188-435a-8636-2ec0e6a096e2"
3-
version = "0.1.11"
3+
version = "0.1.12"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,6 @@ using LinearAlgebra: AbstractTriangular
55
using SparseArrays
66
using SuiteSparse
77

8-
@static if isdefined(Base, :ReshapedReinterpretArray)
9-
_is_reshaped(::Type{<:Base.ReshapedReinterpretArray}) = true
10-
end
11-
_is_reshaped(::Type{<:Base.ReinterpretArray}) = false
12-
138
@static if isdefined(Base, Symbol("@assume_effects"))
149
using Base: @assume_effects
1510
else
@@ -562,6 +557,28 @@ ndims_shape(@nospecialize T::Type{<:Number}) = 0
562557
ndims_shape(@nospecialize T::Type{<:AbstractArray}) = ndims(T)
563558
ndims_shape(x) = ndims_shape(typeof(x))
564559

560+
"""
561+
IndicesInfo(T::Type{<:Tuple}) -> IndicesInfo{NI,NS,IS}()
562+
563+
Provides basic trait information for each index type in in the tuple `T`. `NI`, `NS`, and
564+
`IS` are tuples of [`ndims_index`](@ref), [`ndims_shape`](@ref), and
565+
[`is_splat_index`](@ref) (respectively) for each field of `T`.
566+
"""
567+
struct IndicesInfo{NI,NS,IS} end
568+
IndicesInfo(@nospecialize x::Tuple) = IndicesInfo(typeof(x))
569+
@generated function IndicesInfo(::Type{T}) where {T<:Tuple}
570+
NI = Expr(:tuple)
571+
NS = Expr(:tuple)
572+
IS = Expr(:tuple)
573+
for i in 1:fieldcount(T)
574+
T_i = fieldtype(T, i)
575+
push!(NI.args, :(ndims_index($(T_i))))
576+
push!(NS.args, :(ndims_shape($(T_i))))
577+
push!(IS.args, :(is_splat_index($(T_i))))
578+
end
579+
Expr(:block, Expr(:meta, :inline), :(IndicesInfo{$(NI),$(NS),$(IS)}()))
580+
end
581+
565582
"""
566583
instances_do_not_alias(::Type{T}) -> Bool
567584

src/ArrayInterface.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ using ArrayInterfaceCore
44
import ArrayInterfaceCore: allowed_getindex, allowed_setindex!, aos_to_soa, buffer,
55
parent_type, fast_matrix_colors, findstructralnz, has_sparsestruct,
66
issingular, isstructured, matrix_colors, restructure, lu_instance,
7-
safevec, zeromatrix, ColoringAlgorithm,
8-
fast_scalar_indexing, parameterless_type, ndims_index, is_splat_index, is_forwarding_wrapper
7+
safevec, zeromatrix, ColoringAlgorithm, fast_scalar_indexing, parameterless_type,
8+
ndims_index, ndims_shape, is_splat_index, is_forwarding_wrapper, IndicesInfo
99

1010
# ArrayIndex subtypes and methods
1111
import ArrayInterfaceCore: ArrayIndex, MatrixIndex, VectorIndex, BidiagonalIndex, TridiagonalIndex

src/indexing.jl

Lines changed: 78 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,6 @@ function Base.last(x::AbstractVector, n::StaticInt)
2222
@inbounds x[max(offset1(x), (stop + one(stop)) - n):stop]
2323
end
2424

25-
function _is_splat(::Type{I}, i::StaticInt) where {I}
26-
if dynamic(is_splat_index(field_type(I, i)))
27-
True()
28-
else
29-
False()
30-
end
31-
end
32-
33-
_ndims_index(::Type{I}, i::StaticInt) where {I} = StaticInt(ndims_index(field_type(I, i)))
34-
3525
"""
3626
to_indices(A, I::Tuple) -> Tuple
3727
@@ -91,69 +81,92 @@ This implementation differs from that of `Base.to_indices` in the following ways
9181
"""
9282
to_indices(A, ::Tuple{}) = ()
9383
@inline function to_indices(a::A, inds::I) where {A,I}
94-
_to_indices(
95-
a,
96-
inds,
97-
IndexStyle(A),
98-
static(ndims(A)),
99-
eachop(_ndims_index, ntuple(static, StaticInt(known_length(I))), I),
100-
eachop(_is_splat, ntuple(static, StaticInt(known_length(I))), I)
101-
)
102-
end
103-
@generated function _to_indices(A, inds::I, ::S, ::StaticInt{N}, ::NDI, ::IS) where {I,S,N,NDI,IS}
104-
cnt = zeros(Int, known_length(NDI))
105-
splat_position = 0
106-
remaining = N
107-
for i in 1:known_length(NDI)
108-
ndi = known(NDI.parameters[i])
109-
splat = known(IS.parameters[i])
110-
if splat && splat_position === 0
111-
splat_position = i
112-
else
113-
remaining -= ndi
114-
cnt[i] = ndi
115-
end
116-
end
117-
if splat_position !== 0
118-
cnt[splat_position] = max(0, remaining)
84+
_to_indices(a, inds, IndexStyle(A), static(ndims(A)), IndicesInfo(I))
85+
end
86+
@generated function _to_indices(a, inds, ::S, ::StaticInt{N}, ::IndicesInfo{NI,NS,IS}) where {S,N,NI,NS,IS}
87+
_to_indices_expr(S, N, NI, NS, IS)
88+
end
89+
function _to_indices_expr(S::DataType, N::Int, ni, ns, is)
90+
blk = Expr(:block, Expr(:meta, :inline))
91+
# check to see if we are dealing with linear indexing over a multidimensional array
92+
if length(ni) == 1 && ni[1] === 1
93+
push!(blk.args, :((to_index(LazyAxis{:}(a), getfield(inds, 1)),)))
11994
else
120-
# if there are additional trailing dimensions not consumed by the index then we have
121-
# to assume it's linear indexing or that these are trailing dimensions.
122-
cnt[end] += max(0, remaining)
123-
end
95+
indsexpr = Expr(:tuple)
96+
ndi = Int[]
97+
nds = Int[]
98+
isi = Bool[]
99+
# 1. unwrap AbstractCartesianIndex, CartesianIndices, Indices
100+
for i in 1:length(ns)
101+
ns_i = ns[i]
102+
if ns_i isa Tuple
103+
for j in 1:length(ns_i)
104+
push!(ndi, 1)
105+
push!(nds, ns_i[j])
106+
push!(isi, false)
107+
push!(indsexpr.args, :(getfield(getfield(getfield(inds, $i), 1), $j)))
108+
end
109+
else
110+
push!(indsexpr.args, :(getfield(inds, $i)))
111+
push!(ndi, ni[i])
112+
push!(nds, ns_i)
113+
push!(isi, is[i])
114+
end
115+
end
124116

125-
t = Expr(:tuple)
126-
dim = 0
127-
for i in 1:known_length(NDI)
128-
if i === known_length(NDI) && S <: IndexLinear
129-
ICall = :LinearIndices
130-
else
131-
ICall = :CartesianIndices
117+
# 2. find splat indices
118+
splat_position = 0
119+
remaining = N
120+
for i in eachindex(ndi, nds, isi)
121+
if isi[i] && splat_position == 0
122+
splat_position = i
123+
else
124+
remaining -= ndi[i]
125+
end
132126
end
133-
c = cnt[i]
134-
iexpr = :(@inbounds(getfield(inds, $i))::$(I.parameters[i]))
135-
if dim === N
136-
push!(t.args, :(to_index($(ICall)(()), $iexpr)))
137-
elseif c === 1
138-
dim += 1
139-
push!(t.args, :(to_index(@inbounds(getfield(axs, $dim)), $iexpr)))
140-
else
141-
subaxs = Expr(:tuple)
142-
for _ in 1:c
143-
if dim < N
127+
if splat_position !== 0
128+
for _ in 2:remaining
129+
insert!(ndi, splat_position, 1)
130+
insert!(nds, splat_position, 1)
131+
insert!(indsexpr.args, splat_position, indsexpr.args[splat_position])
132+
end
133+
end
134+
135+
# 3. insert `to_index` calls
136+
dim = 0
137+
nndi = length(ndi)
138+
for i in 1:nndi
139+
ndi_i = ndi[i]
140+
if ndi_i == 1
141+
dim += 1
142+
indsexpr.args[i] = :(to_index($(_axis_expr(N, dim)), $(indsexpr.args[i])))
143+
else
144+
subaxs = Expr(:tuple)
145+
for _ in 1:ndi_i
144146
dim += 1
145-
push!(subaxs.args, :(@inbounds(getfield(axs, $dim))))
147+
push!(subaxs.args, _axis_expr(N, dim))
148+
end
149+
if i == nndi && S <: IndexLinear
150+
indsexpr.args[i] = :(to_index(LinearIndices($(subaxs)), $(indsexpr.args[i])))
151+
else
152+
indsexpr.args[i] = :(to_index(CartesianIndices($(subaxs)), $(indsexpr.args[i])))
146153
end
147154
end
148-
push!(t.args, :(to_index($(ICall)($subaxs), $iexpr)))
149155
end
156+
push!(blk.args, Expr(:(=), :axs, :(lazy_axes(a))))
157+
push!(blk.args, :(_flatten_tuples($(indsexpr))))
158+
end
159+
return blk
160+
end
161+
162+
function _axis_expr(N::Int, d::Int)
163+
if d <= N
164+
:(getfield(axs, $d))
165+
else # ndims(a)+ can only have indices 1:1
166+
:($(SOneTo(1)))
150167
end
151-
Expr(:block,
152-
Expr(:meta, :inline),
153-
Expr(:(=), :axs, :(lazy_axes(A))),
154-
:(_flatten_tuples($t))
155-
)
156168
end
169+
157170
@generated function _flatten_tuples(inds::I) where {I}
158171
t = Expr(:tuple)
159172
for i in 1:known_length(I)
@@ -409,7 +422,7 @@ _output_shape(x::AbstractRange) = (Base.length(x),)
409422
end
410423
_known_first_isone(ind) = known_first(ind) !== nothing && isone(known_first(ind))
411424
@inline function unsafe_get_collection(A::LinearIndices{N}, inds) where {N}
412-
if Base.length(inds) === 1 && isone(_ndims_index(typeof(inds), static(1)))
425+
if Base.length(inds) === 1 && ndims_index(typeof(first(inds))) === 1
413426
return @inbounds(eachindex(A)[first(inds)])
414427
elseif stride_preserving_index(typeof(inds)) === True() &&
415428
reduce_tup(&, map(_known_first_isone, inds))
@@ -464,7 +477,6 @@ function unsafe_setindex!(a::A, v, i::CanonicalInt, ii::Vararg{CanonicalInt}) wh
464477
end
465478
end
466479

467-
468480
function unsafe_setindex!(A::Array{T}, v) where {T}
469481
Base.arrayset(false, A, convert(T, v)::T, 1)
470482
end

0 commit comments

Comments
 (0)