Skip to content

Commit f957da5

Browse files
authored
Merge pull request #122 from SciML/bitarray
Add stridelayout support for `BitArray`
2 parents baccd56 + 8864fe0 commit f957da5

File tree

4 files changed

+40
-13
lines changed

4 files changed

+40
-13
lines changed

src/ArrayInterface.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -644,6 +644,7 @@ defines_strides(::Type{<:StridedArray}) = true
644644
function defines_strides(::Type{<:SubArray{T,N,P,I}}) where {T,N,P,I}
645645
return stride_preserving_index(I) === True()
646646
end
647+
defines_strides(::Type{<:BitArray}) = true
647648

648649
"""
649650
can_avx(f)

src/ranges.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -539,7 +539,7 @@ Base.Slice(Static(1):100)
539539
```
540540
"""
541541
@generated function reduce_tup(f::F, inds::Tuple{Vararg{Any,N}}) where {F,N}
542-
q = Expr(:block, Expr(:meta, :inline))
542+
q = Expr(:block, Expr(:meta, :inline, :propagate_inbounds))
543543
if N == 1
544544
push!(q.args, :(inds[1]))
545545
return q
@@ -566,7 +566,7 @@ Base.Slice(Static(1):100)
566566
q
567567
end
568568

569-
@inline function _pick_range(x, y)
569+
@propagate_inbounds function _pick_range(x, y)
570570
fst = _try_static(static_first(x), static_first(y))
571571
lst = _try_static(static_last(x), static_last(y))
572572
return Base.Slice(OptionallyStaticUnitRange(fst, lst))
@@ -591,19 +591,19 @@ specified, then the indices for visiting each index of `x` are returned.
591591
end
592592
@inline indices(x::AbstractUnitRange{<:Integer}) = Base.Slice(OptionallyStaticUnitRange(x))
593593

594-
function indices(x::Tuple)
594+
@propagate_inbounds function indices(x::Tuple)
595595
inds = map(eachindex, x)
596596
return reduce_tup(_pick_range, inds)
597597
end
598598

599599
@inline indices(x, d) = indices(axes(x, d))
600600

601-
@inline function indices(x::Tuple{Vararg{Any,N}}, dim) where {N}
601+
@propagate_inbounds function indices(x::Tuple{Vararg{Any,N}}, dim) where {N}
602602
inds = map(x_i -> indices(x_i, dim), x)
603603
return reduce_tup(_pick_range, inds)
604604
end
605605

606-
@inline function indices(x::Tuple{Vararg{Any,N}}, dim::Tuple{Vararg{Any,N}}) where {N}
606+
@propagate_inbounds function indices(x::Tuple{Vararg{Any,N}}, dim::Tuple{Vararg{Any,N}}) where {N}
607607
inds = map(indices, x, dim)
608608
return reduce_tup(_pick_range, inds)
609609
end

src/stridelayout.jl

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,10 @@ function contiguous_axis(::Type{T}) where {T}
4646
return contiguous_axis(parent_type(T))
4747
end
4848
end
49-
contiguous_axis(::Type{<:Array}) = StaticInt{1}()
50-
contiguous_axis(::Type{<:Tuple}) = StaticInt{1}()
49+
contiguous_axis(::Type{<:Array}) = One()
50+
contiguous_axis(::Type{<:BitArray}) = One()
51+
contiguous_axis(::Type{<:AbstractRange}) = One()
52+
contiguous_axis(::Type{<:Tuple}) = One()
5153
function contiguous_axis(::Type{T}) where {T<:VecAdjTrans}
5254
c = contiguous_axis(parent_type(T))
5355
if c === nothing
@@ -138,6 +140,8 @@ function stride_rank(::Type{T}) where {T}
138140
end
139141
end
140142
stride_rank(::Type{Array{T,N}}) where {T,N} = nstatic(Val(N))
143+
stride_rank(::Type{BitArray{N}}) where {N} = nstatic(Val(N))
144+
stride_rank(::Type{<:AbstractRange}) = (One(),)
141145
stride_rank(::Type{<:Tuple}) = (One(),)
142146

143147
stride_rank(::Type{T}) where {T<:VecAdjTrans} = (StaticInt(2), StaticInt(1))
@@ -187,8 +191,10 @@ function _contiguous_batch_size(::StaticInt{D}, ::R) where {D,R<:Tuple}
187191
end
188192
end
189193

190-
contiguous_batch_size(::Type{Array{T,N}}) where {T,N} = StaticInt{0}()
191-
contiguous_batch_size(::Type{<:Tuple}) = StaticInt{0}()
194+
contiguous_batch_size(::Type{Array{T,N}}) where {T,N} = Zero()
195+
contiguous_batch_size(::Type{BitArray{N}}) where {N} = Zero()
196+
contiguous_batch_size(::Type{<:AbstractRange}) = Zero()
197+
contiguous_batch_size(::Type{<:Tuple}) = Zero()
192198
function contiguous_batch_size(::Type{T}) where {T<:Union{Transpose,Adjoint}}
193199
return contiguous_batch_size(parent_type(T))
194200
end
@@ -216,6 +222,7 @@ Returns `Val{true}` if elements of `A` are stored in column major order. Otherwi
216222
is_column_major(A) = is_column_major(stride_rank(A), contiguous_batch_size(A))
217223
is_column_major(sr::Nothing, cbs) = False()
218224
is_column_major(sr::R, cbs) where {R} = _is_column_major(sr, cbs)
225+
is_column_major(::AbstractRange) = False()
219226

220227
# cbs > 0
221228
_is_column_major(sr::R, cbs::StaticInt) where {R} = False()
@@ -239,6 +246,8 @@ end
239246
_all_dense(::Val{N}) where {N} = ntuple(_ -> True(), Val{N}())
240247

241248
dense_dims(::Type{Array{T,N}}) where {T,N} = _all_dense(Val{N}())
249+
dense_dims(::Type{BitArray{N}}) where {N} = _all_dense(Val{N}())
250+
dense_dims(::Type{<:AbstractRange}) = (True(),)
242251
dense_dims(::Type{<:Tuple}) = (True(),)
243252
function dense_dims(::Type{T}) where {T<:VecAdjTrans}
244253
dense = dense_dims(parent_type(T))
@@ -413,6 +422,7 @@ function strides(x)
413422
end
414423
#@inline strides(A) = _strides(A, Base.strides(A), contiguous_axis(A))
415424

425+
strides(::AbstractRange) = (One(),)
416426
function strides(x::VecAdjTrans)
417427
st = first(strides(parent(x)))
418428
return (st, st)
@@ -457,6 +467,8 @@ function strides(a::A, dim::Integer) where {A}
457467
end
458468
end
459469

470+
471+
460472
@inline stride(A::AbstractArray, ::StaticInt{N}) where {N} = strides(A)[N]
461473
@inline stride(A::AbstractArray, ::Val{N}) where {N} = strides(A)[N]
462474
stride(A, i) = Base.stride(A, i) # for type stability

test/runtests.jl

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,7 @@ using OffsetArrays
303303
# R = reshape(view(x, 1:100), (10,10));
304304
# A = zeros(3,4,5);
305305
A = Wrapper(reshape(view(x, 1:60), (3,4,5)))
306+
B = A .== 0;
306307
D1 = view(A, 1:2:3, :, :) # first dimension is discontiguous
307308
D2 = view(A, :, 2:2:4, :) # first dimension is contiguous
308309

@@ -312,6 +313,8 @@ using OffsetArrays
312313
@test !@inferred(ArrayInterface.defines_strides(view(A, :, [1,2],1)))
313314

314315
@test @inferred(device(A)) === ArrayInterface.CPUPointer()
316+
@test @inferred(device(B)) === ArrayInterface.CPUIndex()
317+
@test @inferred(device(-1:19)) === ArrayInterface.CPUIndex()
315318
@test @inferred(device((1,2,3))) === ArrayInterface.CPUIndex()
316319
@test @inferred(device(PermutedDimsArray(A,(3,1,2)))) === ArrayInterface.CPUPointer()
317320
@test @inferred(device(view(A, 1, :, 2:4))) === ArrayInterface.CPUPointer()
@@ -330,6 +333,8 @@ using OffsetArrays
330333
=#
331334
@test @inferred(contiguous_axis(@SArray(zeros(2,2,2)))) === ArrayInterface.StaticInt(1)
332335
@test @inferred(contiguous_axis(A)) === ArrayInterface.StaticInt(1)
336+
@test @inferred(contiguous_axis(B)) === ArrayInterface.StaticInt(1)
337+
@test @inferred(contiguous_axis(-1:19)) === ArrayInterface.StaticInt(1)
333338
@test @inferred(contiguous_axis(D1)) === ArrayInterface.StaticInt(-1)
334339
@test @inferred(contiguous_axis(D2)) === ArrayInterface.StaticInt(1)
335340
@test @inferred(contiguous_axis(PermutedDimsArray(A,(3,1,2)))) === ArrayInterface.StaticInt(2)
@@ -350,6 +355,8 @@ using OffsetArrays
350355

351356
@test @inferred(ArrayInterface.contiguous_axis_indicator(@SArray(zeros(2,2,2)))) == (true,false,false)
352357
@test @inferred(ArrayInterface.contiguous_axis_indicator(A)) == (true,false,false)
358+
@test @inferred(ArrayInterface.contiguous_axis_indicator(B)) == (true,false,false)
359+
@test @inferred(ArrayInterface.contiguous_axis_indicator(-1:10)) == (true,)
353360
@test @inferred(ArrayInterface.contiguous_axis_indicator(PermutedDimsArray(A,(3,1,2)))) == (false,true,false)
354361
@test @inferred(ArrayInterface.contiguous_axis_indicator(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:]))) == (true,false)
355362
@test @inferred(ArrayInterface.contiguous_axis_indicator(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:])')) == (false,true)
@@ -362,6 +369,8 @@ using OffsetArrays
362369

363370
@test @inferred(contiguous_batch_size(@SArray(zeros(2,2,2)))) === ArrayInterface.StaticInt(0)
364371
@test @inferred(contiguous_batch_size(A)) === ArrayInterface.StaticInt(0)
372+
@test @inferred(contiguous_batch_size(B)) === ArrayInterface.StaticInt(0)
373+
@test @inferred(contiguous_batch_size(-1:18)) === ArrayInterface.StaticInt(0)
365374
@test @inferred(contiguous_batch_size(PermutedDimsArray(A,(3,1,2)))) === ArrayInterface.StaticInt(0)
366375
@test @inferred(contiguous_batch_size(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:]))) === ArrayInterface.StaticInt(0)
367376
@test @inferred(contiguous_batch_size(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:])')) === ArrayInterface.StaticInt(0)
@@ -372,6 +381,8 @@ using OffsetArrays
372381

373382
@test @inferred(stride_rank(@SArray(zeros(2,2,2)))) == (1, 2, 3)
374383
@test @inferred(stride_rank(A)) == (1,2,3)
384+
@test @inferred(stride_rank(B)) == (1,2,3)
385+
@test @inferred(stride_rank(-4:4)) == (1,)
375386
@test @inferred(stride_rank(view(A,:,:,1))) === (static(1), static(2))
376387
@test @inferred(stride_rank(view(A,:,:,1))) === ((ArrayInterface.StaticInt(1),ArrayInterface.StaticInt(2)))
377388
@test @inferred(stride_rank(PermutedDimsArray(A,(3,1,2)))) == (3, 1, 2)
@@ -400,6 +411,8 @@ using OffsetArrays
400411

401412
@test @inferred(ArrayInterface.is_column_major(@SArray(zeros(2,2,2)))) === True()
402413
@test @inferred(ArrayInterface.is_column_major(A)) === True()
414+
@test @inferred(ArrayInterface.is_column_major(B)) === True()
415+
@test @inferred(ArrayInterface.is_column_major(-4:7)) === False()
403416
@test @inferred(ArrayInterface.is_column_major(PermutedDimsArray(A,(3,1,2)))) === False()
404417
@test @inferred(ArrayInterface.is_column_major(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:]))) === True()
405418
@test @inferred(ArrayInterface.is_column_major(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:])')) === False()
@@ -408,11 +421,12 @@ using OffsetArrays
408421
@test @inferred(ArrayInterface.is_column_major(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:])')) === True()
409422
@test @inferred(ArrayInterface.is_column_major(@view(PermutedDimsArray(A,(3,1,2))[:,1:2,1])')) === True()
410423
@test @inferred(ArrayInterface.is_column_major(@view(PermutedDimsArray(A,(3,1,2))[:,2,1])')) === False()
411-
@test @inferred(ArrayInterface.is_column_major(1:10)) === False()
412424
@test @inferred(ArrayInterface.is_column_major(2.3)) === False()
413425

414426
@test @inferred(dense_dims(@SArray(zeros(2,2,2)))) == (true,true,true)
415427
@test @inferred(dense_dims(A)) == (true,true,true)
428+
@test @inferred(dense_dims(B)) == (true,true,true)
429+
@test @inferred(dense_dims(-3:9)) == (true,)
416430
@test @inferred(dense_dims(PermutedDimsArray(A,(3,1,2)))) == (true,true,true)
417431
@test @inferred(dense_dims(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:]))) == (true,false)
418432
@test @inferred(dense_dims(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:])')) == (false,true)
@@ -434,9 +448,9 @@ using OffsetArrays
434448
@test @inferred(dense_dims(view(DummyZeros(3,4), :, 1))) === nothing
435449
@test @inferred(dense_dims(view(DummyZeros(3,4), :, 1)')) === nothing
436450

437-
B = Array{Int8}(undef, 2,2,2,2);
438-
doubleperm = PermutedDimsArray(PermutedDimsArray(B,(4,2,3,1)), (4,2,1,3));
439-
@test collect(strides(B))[collect(stride_rank(doubleperm))] == collect(strides(doubleperm))
451+
C = Array{Int8}(undef, 2,2,2,2);
452+
doubleperm = PermutedDimsArray(PermutedDimsArray(C,(4,2,3,1)), (4,2,1,3));
453+
@test collect(strides(C))[collect(stride_rank(doubleperm))] == collect(strides(doubleperm))
440454

441455
@test @inferred(ArrayInterface.indices(OffsetArray(view(PermutedDimsArray(A, (3,1,2)), 1, :, 2:4)', 3, -173),1)) === Base.Slice(ArrayInterface.OptionallyStaticUnitRange(4,6))
442456
@test @inferred(ArrayInterface.indices(OffsetArray(view(PermutedDimsArray(A, (3,1,2)), 1, :, 2:4)', 3, -173),2)) === Base.Slice(ArrayInterface.OptionallyStaticUnitRange(-172,-170))

0 commit comments

Comments
 (0)