Skip to content

Commit a9c1895

Browse files
committed
Add Zygote support for vmap.
1 parent fb45ed3 commit a9c1895

File tree

7 files changed

+179
-21
lines changed

7 files changed

+179
-21
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
name = "LoopVectorization"
22
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
33
authors = ["Chris Elrod <[email protected]>"]
4-
version = "0.9.18"
4+
version = "0.9.19"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
88
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
9+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
910
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
1011
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1112
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
13+
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1214
SLEEFPirates = "476501e8-09a2-5ece-8869-fb82de89a1fa"
1315
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
1416
VectorizationBase = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"

src/LoopVectorization.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,12 @@ using ArrayInterface
3232
using ArrayInterface: OptionallyStaticUnitRange, Zero, One#, static_length
3333
const Static = ArrayInterface.StaticInt
3434

35+
using Requires
3536

3637

3738
export LowDimArray, stridedpointer,
3839
@avx, @_avx, *ˡ, _avx_!,
39-
vmap, vmap!, vmapt, vmapt!, vmapnt, vmapnt!, vmapntt, vmapntt!,
40+
vmap, vmap!, vmapt, vmapt!, vmapnt, vmapnt!, vmapntt, vmapntt!, tanh_fast,
4041
vfilter, vfilter!, vmapreduce, vreduce
4142

4243
@inline unwrap(::Val{N}) where {N} = N
@@ -45,7 +46,6 @@ export LowDimArray, stridedpointer,
4546

4647
const VECTORWIDTHSYMBOL, ELTYPESYMBOL = Symbol("##Wvecwidth##"), Symbol("##Tloopeltype##")
4748

48-
4949
include("vectorizationbase_compat/contract_pass.jl")
5050
include("vectorizationbase_compat/subsetview.jl")
5151
include("closeopen.jl")
@@ -96,4 +96,8 @@ LoopVectorization
9696
include("precompile.jl")
9797
_precompile_()
9898

99+
function __init__()
100+
@require ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" include("vmap_grad.jl")
101+
end
102+
99103
end # module

src/map.jl

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -47,33 +47,40 @@ function vmap_singlethread!(
4747
V = VectorizationBase.pick_vector_width_val(T)
4848
W = unwrap(V)
4949
st = VectorizationBase.static_sizeof(T)
50-
zero_index = MM{W}(Static(0), st)
51-
while i < N - ((W << 2) - 1)
50+
UNROLL = 4
51+
LOG2UNROLL = 2
52+
while i < N - ((W << LOG2UNROLL) - 1)
5253

53-
# vstore!(stridedpointer(B), VectorizationBase.VecUnroll((v1,v2,v3)), VectorizationBase.Unroll{AU,1,3,AV,W64,zero(UInt)}((i, j, k)))
54-
# vload(stridedpointer(B), VectorizationBase.Unroll{1,1,4,1,W,0x0000000000000000}((i,)))
55-
56-
index = VectorizationBase.Unroll{1,1,4,1,W,0x0000000000000000}((i,))
54+
index = VectorizationBase.Unroll{1,1,UNROLL,1,W,0x0000000000000000}((i,))
5755
v = f(vload.(ptrargs, index)...)
5856
if NonTemporal
5957
vstorent!(ptry, v, index)
6058
else
6159
vnoaliasstore!(ptry, v, index)
6260
end
63-
i = vadd_fast(i, 4W)
61+
i = vadd_fast(i, StaticInt{UNROLL}() * W)
6462
end
65-
while i < N - (W - 1) # stops at 16 when
66-
vᵣ = f(vload.(ptrargs, ((MM{W}(i),),))...)
67-
if NonTemporal
68-
vstorent!(ptry, vᵣ, (MM{W}(i),))
69-
else
70-
vnoaliasstore!(ptry, vᵣ, (MM{W}(i),))
63+
if Base.libllvm_version v"11"
64+
Nm1 = vsub_fast(N, 1)
65+
while i < N # stops at 16 when
66+
m = mask(V, i, Nm1)
67+
vnoaliasstore!(ptry, f(vload.(ptrargs, ((MM{W}(i),),), m)...), (MM{W}(i,),), m)
68+
i = vadd_fast(i, W)
69+
end
70+
else
71+
while i < N - (W - 1) # stops at 16 when
72+
vᵣ = f(vload.(ptrargs, ((MM{W}(i),),))...)
73+
if NonTemporal
74+
vstorent!(ptry, vᵣ, (MM{W}(i),))
75+
else
76+
vnoaliasstore!(ptry, vᵣ, (MM{W}(i),))
77+
end
78+
i = vadd_fast(i, W)
79+
end
80+
if i < N
81+
m = mask(T, N & (W - 1))
82+
vnoaliasstore!(ptry, f(vload.(ptrargs, ((MM{W}(i),),), m)...), (MM{W}(i,),), m)
7183
end
72-
i = vadd_fast(i, W)
73-
end
74-
if i < N
75-
m = mask(T, N & (W - 1))
76-
vnoaliasstore!(ptry, f(vload.(ptrargs, ((MM{W}(i),),), m)...), (MM{W}(i,),), m)
7784
end
7885
y
7986
end

src/vmap_grad.jl

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
2+
using ForwardDiff
3+
using VectorizationBase: AbstractSIMD
4+
5+
@generated function SLEEFPirates.tanh_fast(x::ForwardDiff.Dual{T,S,N}) where {T,S,N}
6+
quote
7+
$(Expr(:meta,:inline))
8+
t = tanh_fast(x.value)
9+
∂t = vfnmadd_fast(t, t, one(S))
10+
p = x.partials
11+
ForwardDiff.Dual(t, ForwardDiff.Partials(Base.Cartesian.@ntuple $N n -> mul_fast(∂t, p[n])))
12+
end
13+
end
14+
function ChainRulesCore.rrule(::typeof(tanh_fast), x)
15+
t = tanh_fast(x)
16+
t, y -> (ChainRulesCore.Zero(), mul_fast(vfnmadd_fast(t, t, one(t)), y))
17+
end
18+
19+
20+
@generated function init_dual(v::Tuple{Vararg{AbstractSIMD,A}}) where {A}
21+
res = Expr(:tuple)
22+
q = Expr(:block, Expr(:meta,:inline))
23+
for a 1:A
24+
v_a = Symbol(:v_,a)
25+
push!(q.args, Expr(:(=), v_a, Expr(:ref, :v, a)))
26+
partials = Expr(:tuple)
27+
for i 1:A
28+
push!(partials.args, Expr(:call, i == a ? :one : :zero, v_a))
29+
end
30+
push!(res.args, :(ForwardDiff.Dual($v_a, ForwardDiff.Partials($partials))))
31+
end
32+
push!(q.args, res)
33+
q
34+
end
35+
36+
@generated function dual_store!(∂p::Tuple{Vararg{AbstractStridedPointer,A}}, p::AbstractStridedPointer, ∂v, im::Vararg{Any,N}) where {A,N}
37+
quote
38+
$(Expr(:meta,:inline))
39+
v = ∂v.value
40+
= ∂v.partials
41+
VectorizationBase.vnoaliasstore!(p, v, im...)
42+
Base.Cartesian.@nexprs $A a -> VectorizationBase.vnoaliasstore!(∂p[a], ∂[a], im...)
43+
nothing
44+
end
45+
end
46+
@generated function dual_store!(∂p::Tuple{Vararg{AbstractStridedPointer,A}}, p::AbstractStridedPointer, ∂v, im::Vararg{Any,N}) where {A,N}
47+
quote
48+
$(Expr(:meta,:inline))
49+
v = ∂v.value
50+
= ∂v.partials
51+
VectorizationBase.vnoaliasstore!(p, v, im...)
52+
Base.Cartesian.@nexprs $A a -> VectorizationBase.vnoaliasstore!(∂p[a], ∂[a], im...)
53+
nothing
54+
end
55+
end
56+
57+
function ∂vmap_singlethread!(
58+
f::F, ∂y::Tuple{Vararg{DenseArray{T},A}}, y::DenseArray{T},
59+
args::Vararg{<:DenseArray{<:Base.HWReal},A}
60+
) where {F,T <: Base.HWReal, A}
61+
N = length(y)
62+
ptry = VectorizationBase.zero_offsets(stridedpointer(y))
63+
ptrargs = VectorizationBase.zero_offsets.(stridedpointer.(args))
64+
ptr∂y = VectorizationBase.zero_offsets.(stridedpointer.(∂y))
65+
66+
i = 0
67+
V = VectorizationBase.pick_vector_width_val(T)
68+
W = Int(V)
69+
st = VectorizationBase.static_sizeof(T)
70+
zero_index = MM{W}(StaticInt(0), st)
71+
while i < N - ((W << 2) - 1)
72+
index = VectorizationBase.Unroll{1,1,4,1,W,0x0000000000000000}((i,))
73+
v = f(init_dual(vload.(ptrargs, index))...)
74+
dual_store!(ptr∂y, ptry, v, index)
75+
i = vadd_fast(i, 4W)
76+
end
77+
while i < N - (W - 1)
78+
vᵣ = f(init_dual(vload.(ptrargs, ((MM{W}(i),),)))...)
79+
dual_store!(ptr∂y, ptry, vᵣ, (MM{W}(i),))
80+
i = vadd_fast(i, W)
81+
end
82+
if i < N
83+
m = mask(T, N & (W - 1))
84+
dual_store!(ptr∂y, ptry, f(init_dual(vload.(ptrargs, ((MM{W}(i),),), m))...), (MM{W}(i,),), m)
85+
end
86+
nothing
87+
end
88+
89+
90+
struct SIMDMapBack{K,T<:Tuple{Vararg{Any,K}}}
91+
jacs::T
92+
end
93+
@generated function (b::SIMDMapBack{K,T})(Δ::A) where {K,T,A}
94+
preloop = Expr(:block, :(jacs = b.jacs))
95+
loop_body = Expr(:block, :(Δᵢ = Δ[i]))
96+
ret = Expr(:tuple, ChainRulesCore.Zero(), ChainRulesCore.Zero())
97+
for k 1:K
98+
jₖ = Symbol(:j_, k)
99+
push!(preloop.args, :($jₖ = jacs[$k]))
100+
push!(loop_body.args, :($jₖ[i] *= Δᵢ))
101+
push!(ret.args, jₖ)
102+
end
103+
quote
104+
$preloop
105+
@avx for i eachindex(Δ)
106+
$loop_body
107+
end
108+
$ret
109+
end
110+
end
111+
112+
function ChainRulesCore.rrule(::typeof(vmap), f::F, args::Vararg{Any,K}) where {F,K}
113+
out = similar(first(args))
114+
jacs = map(similar, args)
115+
∂vmap_singlethread!(f, jacs, out, args...)
116+
out, SIMDMapBack(jacs)
117+
end
118+
119+

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
66
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
77
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
88
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
9+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ const START_TIME = time()
2727

2828
@time include("check_empty.jl")
2929

30+
@time include("zygote.jl")
31+
3032
@time include("offsetarrays.jl")
3133

3234
@time include("tensors.jl")

test/zygote.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
using Zygote
2+
zygotefun(a,b,c) = log(tanh_fast(a) + 1 + exp(b - c))
3+
@testset "Zygote" begin
4+
@show @__LINE__
5+
6+
for T (Float32,Float64)
7+
x = randn(T, 217); y = randn(T, 217); z = randn(T, 217);
8+
9+
# Test 1-arg and `tanh_fast` vs `tanh`
10+
gtref = gradient(x -> sum(map(tanh, x)), x);
11+
gtlv = gradient(x -> sum(vmap(tanh_fast, x)), x);
12+
13+
@test only(gtref) only(gtlv)
14+
15+
# Test multiple arguments
16+
gref = gradient(xyz -> sum(map(zygotefun, xyz...)), (x, y, z));
17+
glv = gradient(xyz -> sum(vmap(zygotefun, xyz...)), (x, y, z));
18+
19+
@test all(map(, only(gref), only(glv)))
20+
end
21+
end
22+
23+

0 commit comments

Comments
 (0)