Skip to content

Commit fa0505b

Browse files
committed
Updates to fix transposed/adjoint broadcasting.
1 parent c68de2b commit fa0505b

File tree

7 files changed

+42
-17
lines changed

7 files changed

+42
-17
lines changed

Manifest.toml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,15 +61,13 @@ uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
6161

6262
[[SIMDPirates]]
6363
deps = ["MacroTools", "VectorizationBase"]
64-
git-tree-sha1 = "296cae2ccd6e4766aad669e748c1248fb99ab69c"
64+
git-tree-sha1 = "5f7fc8a48d9806817bb3f8a2ef9793398744e8a6"
6565
uuid = "21efa798-c60a-11e8-04d3-e1a92915a26a"
66-
version = "0.1.0"
66+
version = "0.1.1"
6767

6868
[[SLEEFPirates]]
6969
deps = ["SIMDPirates", "VectorizationBase"]
7070
git-tree-sha1 = "1c5b6827da87a12bdb7a4c893f44c3adbce3389d"
71-
repo-rev = "master"
72-
repo-url = "https://github.com/chriselrod/SLEEFPirates.jl"
7371
uuid = "476501e8-09a2-5ece-8869-fb82de89a1fa"
7472
version = "0.1.1"
7573

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ VectorizationBase = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"
1414
[compat]
1515
MacroTools = "0.5"
1616
Parameters = "0.12.0"
17-
SIMDPirates = "0.1.0"
17+
SIMDPirates = "0.1.1"
1818
SLEEFPirates = "0.1.1"
1919
VectorizationBase = "0.1.3"
2020
julia = "1.3.0"

README.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,10 +115,10 @@ In the future, I would like it to also model the cost of memory movement in the
115115
Until then, performance will degrade rapidly compared to BLAS as the size of the matrices increase. The advantage of the `@avx` macro, however, is that it is general. Not every operation is supported by BLAS.
116116

117117
For example, what if `A` were the outter product of two vectors?
118-
```julia
118+
<!-- ```julia -->
119119

120120

121-
```
121+
<!-- ``` -->
122122

123123

124124
Another example, a straightforward operation expressed well via broadcasting:
@@ -137,6 +137,8 @@ d2 = @avx @. a + B * c′;
137137
can be optimized in a similar manner to BLAS, albeit to a much smaller degree because the naive version already benefits from vectorization (unlike the naive BLAS).
138138

139139

140+
<!-- You can also use `\ast` to for a lazy matrix multiplication that can fuse with broadcasts. `.\ast` behaves similarly, to allow it's arguments to -->
141+
140142

141143

142144
Originally, LoopVectorization only provided a simple, dumb, transform on a single loop using the `@vectorize` macro. This transformation took element type and unroll factor arguments, performing no analysis of the loop, simply applying the specified arguments.

src/LoopVectorization.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using VectorizationBase, SIMDPirates, SLEEFPirates, MacroTools, Parameters
44
using VectorizationBase: REGISTER_SIZE, REGISTER_COUNT, extract_data, num_vector_load_expr, mask, masktable, pick_vector_width_val, valmul, valrem, valmuladd
55
using SIMDPirates: VECTOR_SYMBOLS, evadd, evmul, vrange, reduced_add, reduced_prod
66
using Base.Broadcast: Broadcasted, DefaultArrayStyle
7-
using LinearAlgebra: Adjoint
7+
using LinearAlgebra: Adjoint, Transpose
88
using MacroTools: prewalk, postwalk
99

1010

src/broadcast.jl

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -95,19 +95,31 @@ function add_broadcast!(
9595
ref = ArrayReference(bcname, fulldims, Ref{Bool}(false))
9696
add_load!(ls, destname, ref, elementbytes)::Operation
9797
end
98+
function add_broadcast_adjoint_array!(
99+
ls::LoopSet, destname::Symbol, bcname::Symbol, loopsyms::Vector{Symbol}, ::Type{A}, elementbytes::Int = 8
100+
) where {T,N,A<:AbstractArray{T,N}}
101+
parent = gensym(:parent)
102+
pushpreamble!(ls, Expr(:(=), parent, Expr(:call, :parent, bcname)))
103+
ref = ArrayReference(parent, Union{Symbol,Int}[loopsyms[N + 1 - n] for n 1:N], Ref{Bool}(false))
104+
add_load!( ls, destname, ref, elementbytes )::Operation
105+
end
106+
function add_broadcast_adjoint_array!(
107+
ls::LoopSet, destname::Symbol, bcname::Symbol, loopsyms::Vector{Symbol}, ::Type{<:AbstractVector}, elementbytes::Int = 8
108+
)
109+
ref = ArrayReference(bcname, Union{Symbol,Int}[loopsyms[2]], Ref{Bool}(false))
110+
add_load!( ls, destname, ref, elementbytes )
111+
end
98112
function add_broadcast!(
99113
ls::LoopSet, destname::Symbol, bcname::Symbol, loopsyms::Vector{Symbol},
100114
::Type{Adjoint{T,A}}, elementbytes::Int = 8
101-
) where {T, N, A <: AbstractArray{T,N}}
102-
ref = ArrayReference(bcname, Union{Symbol,Int}[loopsyms[N + 1 - n] for n 1:N], Ref{Bool}(false))
103-
add_load!( ls, destname, ref, elementbytes )::Operation
115+
) where {T, A <: AbstractArray{T}}
116+
add_broadcast_adjoint_array!( ls, destname, bcname, loopsyms, A, elementbytes )
104117
end
105118
function add_broadcast!(
106119
ls::LoopSet, destname::Symbol, bcname::Symbol, loopsyms::Vector{Symbol},
107-
::Type{Adjoint{T,V}}, elementbytes::Int = 8
108-
) where {T, V <: AbstractVector{T}}
109-
ref = ArrayReference(bcname, Union{Symbol,Int}[loopsyms[2]], Ref{Bool}(false))
110-
add_load!( ls, destname, ref, elementbytes )
120+
::Type{Transpose{T,A}}, elementbytes::Int = 8
121+
) where {T, A <: AbstractArray{T}}
122+
add_broadcast_adjoint_array!( ls, destname, bcname, loopsyms, A, elementbytes )
111123
end
112124
function add_broadcast!(
113125
ls::LoopSet, destname::Symbol, bcname::Symbol, loopsyms::Vector{Symbol},

src/lowering.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,10 @@ function lower_load_unrolled!(
8888
else
8989
sn = findfirst(x -> x === unrolled, loopdependencies(op))::Int
9090
ustrides = Expr(:call, lv(:vmul), Expr(:call, :stride, ptr, sn), Expr(:call, lv(:vrange), W))
91+
ustride = gensym(:ustride)
92+
push!(q.args, Expr(:(=), ustride, ustrides))
9193
for u 0:U-1
92-
instrcall = Expr(:call, lv(:gather), ptr, Expr(:call,lv(:vadd),mem_offset(op, u, W, unrolled),ustrides))
94+
instrcall = Expr(:call, lv(:gather), ptr, mem_offset(op, u, W, unrolled), ustride)
9395
if mask !== nothing && u == U - 1
9496
push!(instrcall.args, mask)
9597
end
@@ -207,7 +209,7 @@ function lower_store_unrolled!(
207209
sn = findfirst(x -> x === unrolled, loopdependencies(op))::Int
208210
ustrides = Expr(:call, lv(:vmul), Expr(:call, :stride, ptr, sn), Expr(:call, lv(:vrange), W))
209211
for u 0:U-1
210-
instrcall = Expr(:call, lv(:scatter!), ptr, Symbol("##",var,:_,u), Expr(:call,lv(:vadd),mem_offset(op,u,W,unrolled),ustrides))
212+
instrcall = Expr(:call, lv(:scatter!), ptr, mem_offset(op,u,W,unrolled), ustrides, Symbol("##",var,:_,u))
211213
if mask !== nothing && u == U - 1
212214
push!(instrcall.args, mask)
213215
end

test/runtests.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using Test
22
using LoopVectorization
3+
using LinearAlgebra
34

45

56
@testset "LoopVectorization.jl" begin
@@ -356,6 +357,16 @@ end
356357
D1 = C .+ A * B;
357358
D2 = @avx C .+ A B;
358359
@test D1 D2
360+
361+
D3 = exp.(B')
362+
D4 = @avx exp.(B')
363+
@test D3 D4
364+
365+
fill!(D3, -1e3); fill!(D4, 9e9)
366+
Bt = Transpose(B)
367+
@. D3 = exp(Bt)
368+
@avx @. D4 = exp(Bt)
369+
@test D3 D4
359370
end
360371
end
361372

0 commit comments

Comments
 (0)