Skip to content

Commit 8dcdca8

Browse files
committed
If a LowDimArray doesn't have true/false indicated, check size for stride.
1 parent 3d569d6 commit 8dcdca8

File tree

2 files changed

+17
-2
lines changed

2 files changed

+17
-2
lines changed

src/broadcast.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,10 +116,19 @@ struct LowDimArray{D,T,N,A<:DenseArray{T,N}} <: DenseArray{T,N}
116116
end
117117
@inline Base.pointer(A::LowDimArray) = pointer(A.data)
118118
Base.@propagate_inbounds Base.getindex(A::LowDimArray, i...) = getindex(A.data, i...)
119-
Base.size(A::LowDimArray) = Base.size(A.data)
119+
@inline Base.size(A::LowDimArray) = Base.size(A.data)
120+
@inline Base.size(A::LowDimArray, i) = Base.size(A.data, i)
120121
@generated function VectorizationBase.stridedpointer(A::LowDimArray{D,T,N}) where {D,T,N}
121122
smul = Expr(:(.), Expr(:(.), :LoopVectorization, QuoteNode(:VectorizationBase)), QuoteNode(:staticmul))
122-
s = Expr(:call, smul, T, Expr(:tuple, [Expr(:ref, :strideA, n) for n ∈ 1+D[1]:N if ((length(D) < n) || D[n])]...))
123+
multup = Expr(:tuple)
124+
for n ∈ 1:N
125+
if length(D) < n
126+
push!(multup.args, Expr(:call, :ifelse, :(isone(size(A,$n))), 0, Expr(:ref, :strideA, n)))
127+
elseif D[n]
128+
push!(multup.args, Expr(:ref, :strideA, n))
129+
end
130+
end
131+
s = Expr(:call, smul, T, multup)
123132
f = D[1] ? :PackedStridedPointer : :SparseStridedPointer
124133
Expr(:block, Expr(:meta,:inline), Expr(:(=), :strideA, Expr(:call, :strides, Expr(:(.), :A, QuoteNode(:data)))),
125134
Expr(:call, Expr(:(.), :VectorizationBase, QuoteNode(f)), Expr(:call, :pointer, Expr(:(.), :A, QuoteNode(:data))), s))

test/broadcast.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,12 @@
2828
fill!(c2, 99999);
2929
@avx @. c2 = a + bl;
3030
@test c1 ≈ c2
31+
br = reshape(rand(99), (1,99,1));
32+
bl = LowDimArray{(false,)}(br);
33+
@. c1 = a + br;
34+
fill!(c2, 99999);
35+
@avx @. c2 = a + bl;
36+
@test c1 ≈ c2
3137

3238
xs = rand(T, M);
3339
max_ = maximum(xs, dims=1)

0 commit comments

Comments
 (0)