Skip to content

Commit 7ee3b5c

Browse files
authored
Merge pull request #90 from SciML/loopvecbase
More changes to support VectorizationBase/LoopVectorization
2 parents 4d205d7 + 3236fd7 commit 7ee3b5c

File tree

6 files changed

+158
-18
lines changed

6 files changed

+158
-18
lines changed

.github/workflows/ci.yml

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
name: CI
2+
on:
3+
pull_request:
4+
branches:
5+
- master
6+
push:
7+
branches:
8+
- master
9+
tags: '*'
10+
jobs:
11+
test:
12+
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }}
13+
runs-on: ${{ matrix.os }}
14+
strategy:
15+
fail-fast: false
16+
matrix:
17+
version:
18+
- '1.2'
19+
- '1'
20+
- 'nightly'
21+
os:
22+
- ubuntu-latest
23+
arch:
24+
- x64
25+
steps:
26+
- uses: actions/checkout@v2
27+
- uses: julia-actions/setup-julia@v1
28+
with:
29+
version: ${{ matrix.version }}
30+
arch: ${{ matrix.arch }}
31+
- uses: actions/cache@v1
32+
env:
33+
cache-name: cache-artifacts
34+
with:
35+
path: ~/.julia/artifacts
36+
key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }}
37+
restore-keys: |
38+
${{ runner.os }}-test-${{ env.cache-name }}-
39+
${{ runner.os }}-test-
40+
${{ runner.os }}-
41+
- uses: julia-actions/julia-buildpkg@v1
42+
- uses: julia-actions/julia-runtest@v1
43+
- uses: julia-actions/julia-processcoverage@v1
44+
- uses: codecov/codecov-action@v1
45+
with:
46+
file: lcov.info
47+
docs:
48+
name: Documentation
49+
runs-on: ubuntu-latest
50+
steps:
51+
- uses: actions/checkout@v2
52+
- uses: julia-actions/setup-julia@v1
53+
with:
54+
version: '1'
55+
- run: |
56+
julia --project=docs -e '
57+
using Pkg
58+
Pkg.develop(PackageSpec(path=pwd()))
59+
Pkg.instantiate()'
60+
# - run: |
61+
# julia --project=docs -e '
62+
# using Documenter: doctest
63+
# using ArrayInterface
64+
# doctest(ArrayInterface)'
65+
# - run: julia --project=docs docs/make.jl
66+
env:
67+
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
68+
DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }}
69+

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ArrayInterface"
22
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
3-
version = "2.13.8"
3+
version = "2.14.0"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/ArrayInterface.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ parent_type(::Type{<:LinearAlgebra.AbstractTriangular{T,S}}) where {T,S} = S
2525
parent_type(::Type{<:PermutedDimsArray{T,N,I1,I2,A}}) where {T,N,I1,I2,A} = A
2626
parent_type(::Type{Slice{T}}) where {T} = T
2727
parent_type(::Type{T}) where {T} = T
28+
parent_type(::Type{R}) where {S, T, A <: AbstractArray{S}, N, R <: Base.ReinterpretArray{T, N, S, A}} = A
2829

2930
"""
3031
known_length(::Type{T})
@@ -880,10 +881,11 @@ function __init__()
880881
size(A::OffsetArrays.OffsetArray) = size(parent(A))
881882
strides(A::OffsetArrays.OffsetArray) = strides(parent(A))
882883
# offsets(A::OffsetArrays.OffsetArray) = map(+, A.offsets, offsets(parent(A)))
883-
device(::OffsetArrays.OffsetArray) = CheckParent()
884-
contiguous_axis(A::OffsetArrays.OffsetArray) = contiguous_axis(parent(A))
885-
contiguous_batch_size(A::OffsetArrays.OffsetArray) = contiguous_batch_size(parent(A))
886-
stride_rank(A::OffsetArrays.OffsetArray) = stride_rank(parent(A))
884+
parent_type(::Type{O}) where {T,N,A<:AbstractArray{T,N},O<:OffsetArrays.OffsetArray{T,N,A}} = A
885+
device(::Type{<:OffsetArrays.OffsetArray}) = CheckParent()
886+
contiguous_axis(::Type{A}) where {A <: OffsetArrays.OffsetArray} = contiguous_axis(parent_type(A))
887+
contiguous_batch_size(::Type{A}) where {A <: OffsetArrays.OffsetArray} = contiguous_batch_size(parent_type(A))
888+
stride_rank(::Type{A}) where {A <: OffsetArrays.OffsetArray} = stride_rank(parent_type(A))
887889
end
888890
end
889891

src/dimensions.jl

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,20 @@ Return the names of the dimensions for `x`.
3030
end
3131
end
3232
@inline function dimnames(::Type{T}) where {T<:Union{Transpose,Adjoint}}
33-
return _transpose_dimnames(dimnames(parent_type(T)))
33+
return _transpose_dimnames(Val(dimnames(parent_type(T))))
3434
end
35-
_transpose_dimnames(x::Tuple{Symbol,Symbol}) = (last(x), first(x))
36-
_transpose_dimnames(x::Tuple{Symbol}) = (:_, first(x))
35+
# inserting the Val here seems to help inferability; I got a test failure without it.
36+
function _transpose_dimnames(::Val{S}) where {S}
37+
if length(S) == 1
38+
(:_, first(S))
39+
elseif length(S) == 2
40+
(last(S), first(S))
41+
else
42+
throw("Can't transpose $S of dim $(length(S)).")
43+
end
44+
end
45+
@inline _transpose_dimnames(x::Tuple{Symbol,Symbol}) = (last(x), first(x))
46+
@inline _transpose_dimnames(x::Tuple{Symbol}) = (:_, first(x))
3747

3848
@inline function dimnames(::Type{T}) where {I,T<:PermutedDimsArray{<:Any,<:Any,I}}
3949
return map(i -> dimnames(parent_type(T), i), I)
@@ -143,6 +153,8 @@ julia> ArrayInterface.size(A)
143153
"""
144154
size(A) = Base.size(A)
145155
size(A, d) = Base.size(A, to_dims(A, d))
156+
size(x::LinearAlgebra.Adjoint{T,V}) where {T, V <: AbstractVector{T}} = (One(), static_length(x))
157+
size(x::LinearAlgebra.Transpose{T,V}) where {T, V <: AbstractVector{T}} = (One(), static_length(x))
146158

147159
"""
148160
axes(A, d)

src/static.jl

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,21 +22,23 @@ Base.Integer(x::StaticInt{N}) where {N} = x
2222
(::Type{T})(x::StaticInt{N}) where {T<:Integer,N} = T(N)
2323
(::Type{T})(x::Int) where {T<:StaticInt} = StaticInt(x)
2424
Base.convert(::Type{StaticInt{N}}, ::StaticInt{N}) where {N} = StaticInt{N}()
25+
Base.float(::StaticInt{N}) where {N} = Float64(N)
2526

26-
Base.promote_rule(::Type{<:StaticInt}, ::Type{T}) where {T <: AbstractIrrational} = promote_rule(Int, T)
27-
Base.promote_rule(::Type{T}, ::Type{<:StaticInt}) where {T <: AbstractIrrational} = promote_rule(T, Int)
27+
Base.promote_rule(::Type{<:StaticInt}, ::Type{T}) where {T <: Number} = promote_type(Int, T)
28+
Base.promote_rule(::Type{<:StaticInt}, ::Type{T}) where {T <: AbstractIrrational} = promote_type(Int, T)
29+
# Base.promote_rule(::Type{T}, ::Type{<:StaticInt}) where {T <: AbstractIrrational} = promote_rule(T, Int)
2830
for (S,T) [(:Complex,:Real), (:Rational, :Integer), (:(Base.TwicePrecision),:Any)]
29-
@eval Base.promote_rule(::Type{$S{T}}, ::Type{<:StaticInt}) where {T <: $T} = promote_rule($S{T}, Int)
31+
@eval Base.promote_rule(::Type{$S{T}}, ::Type{<:StaticInt}) where {T <: $T} = promote_type($S{T}, Int)
3032
end
3133
Base.promote_rule(::Type{Union{Nothing,Missing}}, ::Type{<:StaticInt}) = Union{Nothing, Missing, Int}
32-
Base.promote_rule(::Type{T}, ::Type{<:StaticInt}) where {T >: Union{Missing,Nothing}} = promote_rule(T, Int)
33-
Base.promote_rule(::Type{T}, ::Type{<:StaticInt}) where {T >: Nothing} = promote_rule(T, Int)
34-
Base.promote_rule(::Type{T}, ::Type{<:StaticInt}) where {T >: Missing} = promote_rule(T, Int)
34+
Base.promote_rule(::Type{T}, ::Type{<:StaticInt}) where {T >: Union{Missing,Nothing}} = promote_type(T, Int)
35+
Base.promote_rule(::Type{T}, ::Type{<:StaticInt}) where {T >: Nothing} = promote_type(T, Int)
36+
Base.promote_rule(::Type{T}, ::Type{<:StaticInt}) where {T >: Missing} = promote_type(T, Int)
3537
for T [:Bool, :Missing, :BigFloat, :BigInt, :Nothing, :Any]
3638
# let S = :Any
3739
@eval begin
38-
Base.promote_rule(::Type{S}, ::Type{$T}) where {S <: StaticInt} = promote_rule(Int, $T)
39-
Base.promote_rule(::Type{$T}, ::Type{S}) where {S <: StaticInt} = promote_rule($T, Int)
40+
Base.promote_rule(::Type{S}, ::Type{$T}) where {S <: StaticInt} = promote_type(Int, $T)
41+
Base.promote_rule(::Type{$T}, ::Type{S}) where {S <: StaticInt} = promote_type($T, Int)
4042
end
4143
end
4244
Base.promote_rule(::Type{<:StaticInt}, ::Type{<:StaticInt}) = Int
@@ -85,6 +87,12 @@ end
8587
for f [:(+), :(-), :(*), :(/), :(÷), :(%), :(<<), :(>>), :(>>>), :(&), :(|), :()]
8688
@eval @generated Base.$f(::StaticInt{M}, ::StaticInt{N}) where {M,N} = Expr(:call, Expr(:curly, :StaticInt, $f(M, N)))
8789
end
90+
for f [:(<<), :(>>), :(>>>)]
91+
@eval begin
92+
@inline Base.$f(::StaticInt{M}, x::UInt) where {M} = $f(M, x)
93+
@inline Base.$f(x::Integer, ::StaticInt{M}) where {M} = $f(x, M)
94+
end
95+
end
8896
for f [:(==), :(!=), :(<), :(), :(>), :()]
8997
@eval begin
9098
@inline Base.$f(::StaticInt{M}, ::StaticInt{N}) where {M,N} = $f(M, N)

src/stridelayout.jl

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,14 @@ _contiguous_axis(::Any, ::Nothing) = nothing
6565
Expr(:call, Expr(:curly, :Contiguous, new_contig))
6666
end
6767

68+
# contiguous_if_one(::Contiguous{1}) = Contiguous{1}()
69+
# contiguous_if_one(::Any) = Contiguous{-1}()
70+
function contiguous_axis(::Type{R}) where {T, N, S, A <: Array{S}, R <: Base.ReinterpretArray{T, N, S, A}}
71+
isbitstype(S) ? Contiguous{1}() : nothing
72+
# contiguous_if_one(contiguous_axis(parent_type(R)))
73+
end
74+
75+
6876
"""
6977
contiguous_axis_indicator(::Type{T}) -> Tuple{Vararg{<:Val}}
7078
@@ -108,6 +116,8 @@ _contiguous_batch_size(::Any, ::Any, ::Any) = nothing
108116
end
109117
end
110118

119+
contiguous_batch_size(::Type{R}) where {T, N, S, A <: Array{S}, R <: Base.ReinterpretArray{T, N, S, A}} = ContiguousBatch{0}()
120+
111121
struct StrideRank{R} end
112122
Base.@pure StrideRank(R::NTuple{<:Any,Int}) = StrideRank{R}()
113123
_get(::StrideRank{R}) where {R} = R
@@ -164,6 +174,7 @@ _stride_rank(::Any, ::Any) = nothing
164174
Expr(:call, Expr(:curly, :StrideRank, ranktup))
165175
end
166176
stride_rank(x, i) = stride_rank(x)[i]
177+
stride_rank(::Type{R}) where {T, N, S, A <: Array{S}, R <: Base.ReinterpretArray{T, N, S, A}} = StrideRank{ntuple(identity, Val{N}())}()
167178

168179
"""
169180
is_column_major(A) -> Val{true/false}()
@@ -248,6 +259,16 @@ julia> A = rand(3,4);
248259
249260
julia> ArrayInterface.strides(A)
250261
(StaticInt{1}(), 3)
262+
263+
Additionally, the behavior differs from `Base.strides` for adjoint vectors:
264+
265+
julia> x = rand(5);
266+
267+
julia> ArrayInterface.strides(x')
268+
(StaticInt{1}(), StaticInt{1}())
269+
270+
This is to support the pattern of using just the first stride for linear indexing, `x[i]`,
271+
while still producing correct behavior when using valid cartesian indices, such as `x[1,i]`.
251272
```
252273
"""
253274
strides(A) = Base.strides(A)
@@ -264,6 +285,16 @@ offsets(::Any) = (StaticInt{1}(),) # Assume arbitrary Julia data structures use
264285
@inline strides(A::Vector{<:Any}) = (StaticInt(1),)
265286
@inline strides(A::Array{<:Any,N}) where {N} = (StaticInt(1), Base.tail(Base.strides(A))...)
266287
@inline strides(A::AbstractArray) = _strides(A, Base.strides(A), contiguous_axis(A))
288+
289+
@inline function strides(x::LinearAlgebra.Adjoint{T,V}) where {T, V <: AbstractVector{T}}
290+
strd = stride(parent(x), One())
291+
(strd, strd)
292+
end
293+
@inline function strides(x::LinearAlgebra.Transpose{T,V}) where {T, V <: AbstractVector{T}}
294+
strd = stride(parent(x), One())
295+
(strd, strd)
296+
end
297+
267298
@generated function _strides(A::AbstractArray{T,N}, s::NTuple{N}, ::Contiguous{C}) where {T,N,C}
268299
if C 0 || C > N
269300
return Expr(:block, Expr(:meta,:inline), :s)
@@ -282,6 +313,22 @@ offsets(::Any) = (StaticInt{1}(),) # Assume arbitrary Julia data structures use
282313
end
283314
end
284315

316+
if VERSION v"1.6.0-DEV.1581"
317+
@generated function _strides(_::Base.ReinterpretArray{T, N, S, A, true}, s::NTuple{N}, ::Contiguous{1}) where {T, N, S, D, A <: Array{S,D}}
318+
stup = Expr(:tuple, :(One()))
319+
if D < N
320+
push!(stup.args, Expr(:call, Expr(:curly, :StaticInt, sizeof(S) ÷ sizeof(T))))
321+
end
322+
for n 2+(D < N):N
323+
push!(stup.args, Expr(:ref, :s, n))
324+
end
325+
quote
326+
$(Expr(:meta,:inline))
327+
@inbounds $stup
328+
end
329+
end
330+
end
331+
285332
@inline function offsets(x, i)
286333
inds = indices(x, i)
287334
start = known_first(inds)
@@ -304,7 +351,7 @@ end
304351
@inline strides(B::PermutedDimsArray{T,N,I1,I2,A}) where {T,N,I1,I2,A<:AbstractArray{T,N}} = permute(strides(parent(B)), Val{I1}())
305352
@inline stride(A::AbstractArray, ::StaticInt{N}) where {N} = strides(A)[N]
306353
@inline stride(A::AbstractArray, ::Val{N}) where {N} = strides(A)[N]
307-
stride(A, i) = Base.stride(A, i)
354+
stride(A, i) = Base.stride(A, i) # for type stability
308355

309356
size(B::S) where {N,NP,T,A<:AbstractArray{T,NP},I,S <: SubArray{T,N,A,I}} = _size(size(parent(B)), B.indices, map(static_length, B.indices))
310357
strides(B::S) where {N,NP,T,A<:AbstractArray{T,NP},I,S <: SubArray{T,N,A,I}} = _strides(strides(parent(B)), B.indices)
@@ -324,8 +371,10 @@ end
324371
@generated function _strides(A::Tuple{Vararg{Any,N}}, inds::I) where {N, I<:Tuple}
325372
t = Expr(:tuple)
326373
for n in 1:N
327-
if I.parameters[n] <: AbstractRange
374+
if I.parameters[n] <: AbstractUnitRange
328375
push!(t.args, Expr(:ref, :A, n))
376+
elseif I.parameters[n] <: AbstractRange
377+
push!(t.args, Expr(:call, :(*), Expr(:ref, :A, n), Expr(:call, :static_step, Expr(:ref, :inds, n))))
329378
elseif !(I.parameters[n] <: Integer)
330379
return nothing
331380
end

0 commit comments

Comments
 (0)