Skip to content

Commit 473d881

Browse files
committed
add sigmoid and relu support
1 parent 4739539 commit 473d881

File tree

4 files changed

+65
-26
lines changed

4 files changed

+65
-26
lines changed

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ version = "0.9.19"
66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
88
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
9-
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
109
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
1110
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1211
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
@@ -20,9 +19,10 @@ ArrayInterface = "2.14.12"
2019
DocStringExtensions = "0.8"
2120
IfElse = "0.1"
2221
OffsetArrays = "1.4.1, 1.5"
23-
SLEEFPirates = "0.6.4"
22+
Requires = "1"
23+
SLEEFPirates = "0.6.6"
2424
UnPack = "1"
25-
VectorizationBase = "0.15.2"
25+
VectorizationBase = "0.15.3"
2626
julia = "1.5"
2727

2828
[extras]

src/LoopVectorization.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ using Requires
3737

3838
export LowDimArray, stridedpointer,
3939
@avx, @_avx, *ˡ, _avx_!,
40-
vmap, vmap!, vmapt, vmapt!, vmapnt, vmapnt!, vmapntt, vmapntt!, tanh_fast,
40+
vmap, vmap!, vmapt, vmapt!, vmapnt, vmapnt!, vmapntt, vmapntt!, tanh_fast, sigmoid_fast,
4141
vfilter, vfilter!, vmapreduce, vreduce
4242

4343
@inline unwrap(::Val{N}) where {N} = N
@@ -96,8 +96,12 @@ LoopVectorization
9696
include("precompile.jl")
9797
_precompile_()
9898

99+
# import ChainRulesCore, ForwardDiff
100+
# include("vmap_grad.jl")
99101
function __init__()
100-
@require ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" include("vmap_grad.jl")
102+
@require ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" begin
103+
@require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" include("vmap_grad.jl")
104+
end
101105
end
102106

103107
end # module

src/vmap_grad.jl

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11

2-
using ForwardDiff
32
using VectorizationBase: AbstractSIMD
43

54
@generated function SLEEFPirates.tanh_fast(x::ForwardDiff.Dual{T,S,N}) where {T,S,N}
@@ -13,7 +12,46 @@ using VectorizationBase: AbstractSIMD
1312
end
1413
function ChainRulesCore.rrule(::typeof(tanh_fast), x)
1514
t = tanh_fast(x)
16-
t, y -> (ChainRulesCore.Zero(), mul_fast(vfnmadd_fast(t, t, one(t)), y))
15+
= let t = t
16+
y -> (ChainRulesCore.Zero(), mul_fast(vfnmadd_fast(t, t, one(t)), y))
17+
end
18+
t, ∂
19+
end
20+
@generated function SLEEFPirates.sigmoid_fast(x::ForwardDiff.Dual{T,S,N}) where {T,S,N}
21+
quote
22+
$(Expr(:meta,:inline))
23+
s = sigmoid_fast(x.value)
24+
∂s = vfnmadd_fast(s,s,s)
25+
p = x.partials
26+
ForwardDiff.Dual(s, ForwardDiff.Partials(Base.Cartesian.@ntuple $N n -> mul_fast(∂s, p[n])))
27+
end
28+
end
29+
function ChainRulesCore.rrule(::typeof(sigmoid_fast), x)
30+
s = sigmoid_fast(x)
31+
= let s = s
32+
y -> (ChainRulesCore.Zero(), mul_fast(vfnmadd_fast(s, s, s), y))
33+
end
34+
s, ∂
35+
end
36+
@generated function VectorizationBase.relu(x::ForwardDiff.Dual{T,S,N}) where {T,S,N}
37+
quote
38+
$(Expr(:meta,:inline))
39+
v = x.value
40+
z = zero(v)
41+
cmp = v < z
42+
r = ifelse(cmp, z, v)
43+
p = x.partials
44+
ForwardDiff.Dual(r, ForwardDiff.Partials(Base.Cartesian.@ntuple $N n -> ifelse(cmp, z, p[n])))
45+
end
46+
end
47+
function ChainRulesCore.rrule(::typeof(relu), v)
48+
z = zero(v)
49+
cmp = v < z
50+
r = ifelse(cmp, z, v)
51+
= let cmp = cmp
52+
y -> (ChainRulesCore.Zero(), ifelse(cmp, zero(y), y))
53+
end
54+
r, ∂
1755
end
1856

1957

@@ -33,16 +71,6 @@ end
3371
q
3472
end
3573

36-
@generated function dual_store!(∂p::Tuple{Vararg{AbstractStridedPointer,A}}, p::AbstractStridedPointer, ∂v, im::Vararg{Any,N}) where {A,N}
37-
quote
38-
$(Expr(:meta,:inline))
39-
v = ∂v.value
40-
= ∂v.partials
41-
VectorizationBase.vnoaliasstore!(p, v, im...)
42-
Base.Cartesian.@nexprs $A a -> VectorizationBase.vnoaliasstore!(∂p[a], ∂[a], im...)
43-
nothing
44-
end
45-
end
4674
@generated function dual_store!(∂p::Tuple{Vararg{AbstractStridedPointer,A}}, p::AbstractStridedPointer, ∂v, im::Vararg{Any,N}) where {A,N}
4775
quote
4876
$(Expr(:meta,:inline))

test/zygote.jl

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,29 @@
11
using Zygote
2-
zygotefun(a,b,c) = log(tanh_fast(a) + 1 + exp(b - c))
2+
zygotefun2(a,b) = sigmoid_fast(a + b) + LoopVectorization.relu(a * b) * tanh(a - b)
3+
zygotefun3(a,b,c) = log(tanh_fast(a*b) + 1 + exp(b - c))
34
@testset "Zygote" begin
45
@show @__LINE__
56

67
for T (Float32,Float64)
78
x = randn(T, 217); y = randn(T, 217); z = randn(T, 217);
8-
9+
w = randn(T, 217)
910
# Test 1-arg and `tanh_fast` vs `tanh`
10-
gtref = gradient(x -> sum(map(tanh, x)), x);
11-
gtlv = gradient(x -> sum(vmap(tanh_fast, x)), x);
11+
gref1 = gradient(x -> dot(w, map(tanh, x)), x);
12+
glv_1 = gradient(x -> dot(w, vmap(tanh_fast, x)), x);
13+
14+
@test only(gref1) only(glv_1)
15+
16+
# Test 2 arguments
17+
# Test 3 arguments
18+
gref2 = gradient(xyz -> dot(w, map(zygotefun2, xyz...)), (x, y));
19+
g_lv2 = gradient(xyz -> dot(w, vmap(zygotefun2, xyz...)), (x, y));
1220

13-
@test only(gtref) only(gtlv)
21+
@test all(map(, only(gref2), only(g_lv2)))
1422

15-
# Test multiple arguments
16-
gref = gradient(xyz -> sum(map(zygotefun, xyz...)), (x, y, z));
17-
glv = gradient(xyz -> sum(vmap(zygotefun, xyz...)), (x, y, z));
23+
gref3 = gradient(xyz -> dot(w, map(zygotefun3, xyz...)), (x, y, z));
24+
g_lv3 = gradient(xyz -> dot(w, vmap(zygotefun3, xyz...)), (x, y, z));
1825

19-
@test all(map(, only(gref), only(glv)))
26+
@test all(map(, only(gref3), only(g_lv3)))
2027
end
2128
end
2229

0 commit comments

Comments
 (0)