Skip to content

Commit a2ce880

Browse files
committed
Fix for broadcasting Float32 and assigning zero(eltype(A)) within loop body.
1 parent bba8487 commit a2ce880

File tree

5 files changed

+40
-22
lines changed

5 files changed

+40
-22
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LoopVectorization"
22
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
33
authors = ["Chris Elrod <[email protected]>"]
4-
version = "0.1.3"
4+
version = "0.1.4"
55

66
[deps]
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

README.md

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,7 @@
99
## Installation
1010
```
1111
using Pkg
12-
Pkg.add(PackageSpec(url="https://github.com/chriselrod/VectorizationBase.jl"))
13-
Pkg.add(PackageSpec(url="https://github.com/chriselrod/SIMDPirates.jl"))
14-
Pkg.add(PackageSpec(url="https://github.com/chriselrod/SLEEFPirates.jl"))
15-
Pkg.add(PackageSpec(url="https://github.com/chriselrod/LoopVectorization.jl"))
12+
Pkg.add("LoopVectorization")
1613
```
1714

1815

src/broadcast.jl

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,27 @@ function Base.Broadcast._broadcast_getindex_eltype(p::Product)
1818
)
1919
end
2020

21+
# recursive_eltype(::Type{A}) where {T, A <: AbstractArray{T}} = T
22+
# recursive_eltype(::Type{NTuple{N,T}}) where {N,T<:Union{Float32,Float64}} = T
23+
# recursive_eltype(::Type{Float32}) = Float32
24+
# recursive_eltype(::Type{Float64}) = Float64
25+
# recursive_eltype(::Type{Tuple{T}}) where {T} = T
26+
# recursive_eltype(::Type{Tuple{T1,T2}}) where {T1,T2} = promote_type(recursive_eltype(T1), recursive_eltype(T2))
27+
# recursive_eltype(::Type{Tuple{T1,T2,T3}}) where {T1,T2,T3} = promote_type(recursive_eltype(T1), recursive_eltype(T2), recursive_eltype(T3))
28+
# recursive_eltype(::Type{Tuple{T1,T2,T3,T4}}) where {T1,T2,T3,T4} = promote_type(recursive_eltype(T1), recursive_eltype(T2), recursive_eltype(T3), recursive_eltype(T4))
29+
# recursive_eltype(::Type{Tuple{T1,T2,T3,T4,T5}}) where {T1,T2,T3,T4,T5} = promote_type(recursive_eltype(T1), recursive_eltype(T2), recursive_eltype(T3), recursive_eltype(T4), recursive_eltype(T5))
30+
31+
# function recursive_eltype(::Type{Broadcasted{S,A,F,ARGS}}) where {S,A,F,ARGS}
32+
# recursive_eltype(ARGS)
33+
# end
2134

2235
@inline (a::A, b::B) where {A,B} = Product{A,B}(a, b)
2336
@inline Base.Broadcast.broadcasted(::typeof(), a::A, b::B) where {A, B} = Product{A,B}(a, b)
2437
# TODO: Need to make this handle A or B being (1 or 2)-D broadcast objects.
2538
function add_broadcast!(
2639
ls::LoopSet, mC::Symbol, bcname::Symbol, loopsyms::Vector{Symbol},
2740
::Type{Product{A,B}}, elementbytes::Int = 8
28-
) where {T,A,B}
41+
) where {A, B}
2942
K = gensym(:K)
3043
mA = gensym(:Aₘₖ)
3144
mB = gensym(:Bₖₙ)
@@ -54,7 +67,12 @@ function add_broadcast!(
5467
# load B
5568
loadB = add_broadcast!(ls, gensym(:B), mB, bloopsyms, B, elementbytes)
5669
# set Cₘₙ = 0
57-
setC = add_constant!(ls, 0.0, cloopsyms, mC, elementbytes)
70+
# setC = add_constant!(ls, zero(promote_type(recursive_eltype(A), recursive_eltype(B))), cloopsyms, mC, elementbytes)
71+
setC = if elementbytes == 4
72+
add_constant!(ls, 0f0, cloopsyms, mC, elementbytes)
73+
else#if elementbytes == 4
74+
add_constant!(ls, 0.0, cloopsyms, mC, elementbytes)
75+
end
5876
# compute Cₘₙ += Aₘₖ * Bₖₙ
5977
reductop = Operation(
6078
ls, mC, elementbytes, :vmuladd, compute, reductdeps, Symbol[k], Operation[loadA, loadB, setC]
@@ -118,7 +136,7 @@ function add_broadcast!(
118136
argname = gensym(:arg)
119137
pushpreamble!(ls, Expr(:(=), argname, Expr(:ref, bcargs, i)))
120138
# dynamic dispatch
121-
parent = add_broadcast!(ls, gensym(:temp), argname, loopsyms, arg)::Operation
139+
parent = add_broadcast!(ls, gensym(:temp), argname, loopsyms, arg, elementbytes)::Operation
122140
pushparent!(parents, deps, reduceddeps, parent)
123141
end
124142
op = Operation(
@@ -130,7 +148,7 @@ end
130148
# size of dest determines loops
131149
@generated function vmaterialize!(
132150
dest::AbstractArray{T,N}, bc::BC
133-
) where {T, N, BC <: Broadcasted}
151+
) where {T <: Union{Float32,Float64}, N, BC <: Broadcasted}
134152
# ) where {N, T, BC <: Broadcasted}
135153
# we have an N dimensional loop.
136154
# need to construct the LoopSet
@@ -143,8 +161,9 @@ end
143161
push!(sizes.args, Nsym)
144162
end
145163
pushpreamble!(ls, Expr(:(=), sizes, Expr(:call, :size, :dest)))
146-
add_broadcast!(ls, :dest, :bc, loopsyms, BC)
147-
add_store!(ls, :dest, ArrayReference(:dest, loopsyms, Ref{Bool}(false)))
164+
elementbytes = sizeof(T)
165+
add_broadcast!(ls, :dest, :bc, loopsyms, BC, elementbytes)
166+
add_store!(ls, :dest, ArrayReference(:dest, loopsyms, Ref{Bool}(false)), elementbytes)
148167
resize!(ls.loop_order, num_loops(ls)) # num_loops may be greater than N, eg Product
149168
q = lower(ls)
150169
push!(q.args, :dest)

src/graphs.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -480,8 +480,13 @@ function add_operation!(
480480
if RHS.head === :ref
481481
add_load_ref!(ls, LHS, RHS, elementbytes)
482482
elseif RHS.head === :call
483-
if first(RHS.args) === :getindex
483+
f = first(RHS.args)
484+
if f === :getindex
484485
add_load_getindex!(ls, LHS, RHS, elementbytes)
486+
elseif f === :zero || f === :one
487+
c = gensym(:constant)
488+
pushpreamble!(ls, Expr(:(=), c, RHS))
489+
add_constant!(ls, c, [keys(ls.loops)...], LHS, elementbytes)
485490
else
486491
add_compute!(ls, LHS, RHS, elementbytes)
487492
end

test/runtests.jl

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ using LoopVectorization
3636

3737
@testset "GEMM" begin
3838
gemmq = :(for i 1:size(A,1), j 1:size(B,2)
39-
Cᵢⱼ = z#ero(eltype(C))
39+
Cᵢⱼ = zero(eltype(C))
4040
for k 1:size(A,2)
4141
Cᵢⱼ += A[i,k] * B[k,j]
4242
end
@@ -57,9 +57,8 @@ using LoopVectorization
5757
end
5858
end
5959
function mygemmavx!(C, A, B)
60-
z = zero(eltype(C))
6160
@avx for i 1:size(A,1), j 1:size(B,2)
62-
Cᵢⱼ = z
61+
Cᵢⱼ = zero(eltype(C))
6362
for k 1:size(A,2)
6463
Cᵢⱼ += A[i,k] * B[k,j]
6564
end
@@ -202,9 +201,8 @@ using LoopVectorization
202201
end
203202
end
204203
function mygemvavx!(y, A, x)
205-
z = zero(eltype(y))
206204
@avx for i eachindex(y)
207-
yᵢ = z
205+
yᵢ = zero(eltype(y))
208206
for j eachindex(x)
209207
yᵢ += A[i,j] * x[j]
210208
end
@@ -262,9 +260,8 @@ using LoopVectorization
262260
end
263261

264262
function mycolsumavx!(x, A)
265-
z = zero(eltype(x))
266263
@avx for j eachindex(x)
267-
xⱼ = z
264+
xⱼ = zero(eltype(x))
268265
for i 1:size(A,2)
269266
xⱼ += A[j,i]
270267
end
@@ -290,9 +287,8 @@ using LoopVectorization
290287
end
291288
end
292289
function myvaravx!(s², A, x̄)
293-
z = zero(eltype(s²))
294290
@avx for j eachindex(s²)
295-
s²ⱼ = z
291+
s²ⱼ = zero(eltype(s²))
296292
x̄ⱼ = x̄[j]
297293
for i 1:size(A,2)
298294
δ = A[j,i] - x̄ⱼ
@@ -328,7 +324,8 @@ end
328324
M, N = 37, 47
329325
# M = 77;
330326
# for T ∈ (Float32, Float64)
331-
let T = Float64
327+
for T (Float64, Float32)
328+
# let T = Float64
332329
a = rand(T, M); B = rand(T, M, N); c = rand(T, N); c′ = c';
333330

334331
d1 = @. a + B * c′;

0 commit comments

Comments
 (0)