Skip to content

Commit 97165ad

Browse files
committed
Updates to a test for new stridedpointer api
1 parent e4fa49e commit 97165ad

File tree

2 files changed

+21
-11
lines changed

2 files changed

+21
-11
lines changed

src/LoopVectorization.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ export LowDimArray, stridedpointer,
3838
vmap, vmap!, vmapt, vmapt!, vmapnt, vmapnt!, vmapntt, vmapntt!,
3939
vfilter, vfilter!, vmapreduce, vreduce
4040

41+
@inline unwrap(::Val{N}) where {N} = N
42+
@inline unwrap(::Static{N}) where {N} = N
4143

4244
const VECTORWIDTHSYMBOL, ELTYPESYMBOL = Symbol("##Wvecwidth##"), Symbol("##Tloopeltype##")
4345

test/offsetarrays.jl

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
using LoopVectorization, OffsetArrays, Test
2-
using LoopVectorization.VectorizationBase: StaticUnitRange
1+
using LoopVectorization, ArrayInterface, OffsetArrays, Test
2+
using LoopVectorization: Static
33
# T = Float64
44
# T = Float32
55

@@ -91,17 +91,25 @@ using LoopVectorization.VectorizationBase: StaticUnitRange
9191
data::Matrix{T}
9292
end
9393
Base.size(::SizedOffsetMatrix{<:Any,LR,UR,LC,UC}) where {LR,UR,LC,UC} = (UR-LR+1,UC-LC+1)
94-
Base.axes(::SizedOffsetMatrix{T,LR,UR,LC,UC}) where {T,LR,UR,LC,UC} = (StaticUnitRange{LR,UR}(),StaticUnitRange{LC,UC}())
94+
Base.axes(::SizedOffsetMatrix{T,LR,UR,LC,UC}) where {T,LR,UR,LC,UC} = (Static{LR}():Static{UR}(),Static{LC}():Static{UC}())
9595
Base.parent(A::SizedOffsetMatrix) = A.data
96-
@generated function LoopVectorization.stridedpointer(A::SizedOffsetMatrix{T,LR,UR,LC,RC}) where {T,LR,UR,LC,RC}
97-
quote
98-
$(Expr(:meta,:inline))
99-
LoopVectorization.OffsetStridedPointer(
100-
LoopVectorization.StaticStridedPointer{$T,Tuple{1,$(UR-LR+1)}}(pointer(parent(A))),
101-
($(LR-1), $(LC-1))
102-
)
103-
end
96+
Base.unsafe_convert(::Type{Ptr{T}}, A::SizedOffsetMatrix{T}) where {T} = pointer(A.data)
97+
ArrayInterface.contiguous_axis(A::SizedOffsetMatrix) = ArrayInterface.Contiguous{1}()
98+
ArrayInterface.contiguous_batch_size(A::SizedOffsetMatrix) = ArrayInterface.ContiguousBatch{0}()
99+
ArrayInterface.stride_rank(A::SizedOffsetMatrix) = ArrayInterface.StrideRank{(1,2)}()
100+
function ArrayInterface.strides(A::SizedOffsetMatrix{T,LR,UR,LC,UC}) where {T,LR,UR,LC,UC}
101+
(Static{1}(), (Static{UR}() - Static{LR}() + Static{1}()))
104102
end
103+
ArrayInterface.offsets(A::SizedOffsetMatrix{T,LR,UR,LC,UC}) where {T,LR,UR,LC,UC} = (Static{LR}(), Static{LC}())
104+
# @generated function LoopVectorization.stridedpointer(A::SizedOffsetMatrix{T,LR,UR,LC,RC}) where {T,LR,UR,LC,RC}
105+
# quote
106+
# $(Expr(:meta,:inline))
107+
# LoopVectorization.OffsetStridedPointer(
108+
# LoopVectorization.StaticStridedPointer{$T,Tuple{1,$(UR-LR+1)}}(pointer(parent(A))),
109+
# ($(LR-1), $(LC-1))
110+
# )
111+
# end
112+
# end
105113
# Base.size(A::SizedOffsetMatrix{T,LR,UR,LC,UC}) where {T,LR,UR,LC,UC} = (1 + UR-LR, 1 + UC-LC)
106114
# Base.CartesianIndices(::SizedOffsetMatrix{T,LR,UR,LC,UC}) where {T,LR,UR,LC,UC} = CartesianIndices((LR:UR,LC:UC))
107115
Base.getindex(A::SizedOffsetMatrix, i, j) = LoopVectorization.vload(LoopVectorization.stridedpointer(A), (i-1,j-1))

0 commit comments

Comments
 (0)