Skip to content

Commit 17bd1ed

Browse files
committed
Set correct dual types in simd dual methods
1 parent 63d50ef commit 17bd1ed

File tree

2 files changed

+48
-48
lines changed

2 files changed

+48
-48
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LoopVectorization"
22
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
33
authors = ["Chris Elrod <[email protected]>"]
4-
version = "0.12.76"
4+
version = "0.12.77"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Lines changed: 47 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,63 +1,63 @@
11
import .ForwardDiff
22

33
@generated function SLEEFPirates.tanh_fast(x::ForwardDiff.Dual{T,S,N}) where {T,S,N}
4-
quote
5-
$(Expr(:meta,:inline))
6-
t = tanh_fast(x.value)
7-
∂t = vfnmadd_fast(t, t, one(S))
8-
p = x.partials
9-
ForwardDiff.Dual(t, ForwardDiff.Partials(Base.Cartesian.@ntuple $N n -> mul_fast(∂t, p[n])))
10-
end
4+
quote
5+
$(Expr(:meta,:inline))
6+
t = tanh_fast(x.value)
7+
∂t = vfnmadd_fast(t, t, one(S))
8+
p = x.partials
9+
ForwardDiff.Dual{T}(t, ForwardDiff.Partials(Base.Cartesian.@ntuple $N n -> mul_fast(∂t, p[n])))
10+
end
1111
end
1212
@generated function SLEEFPirates.sigmoid_fast(x::ForwardDiff.Dual{T,S,N}) where {T,S,N}
13-
quote
14-
$(Expr(:meta,:inline))
15-
s = sigmoid_fast(x.value)
16-
∂s = vfnmadd_fast(s,s,s)
17-
p = x.partials
18-
ForwardDiff.Dual(s, ForwardDiff.Partials(Base.Cartesian.@ntuple $N n -> mul_fast(∂s, p[n])))
19-
end
13+
quote
14+
$(Expr(:meta,:inline))
15+
s = sigmoid_fast(x.value)
16+
∂s = vfnmadd_fast(s,s,s)
17+
p = x.partials
18+
ForwardDiff.Dual{T}(s, ForwardDiff.Partials(Base.Cartesian.@ntuple $N n -> mul_fast(∂s, p[n])))
19+
end
2020
end
2121
@generated function VectorizationBase.relu(x::ForwardDiff.Dual{T,S,N}) where {T,S,N}
22-
quote
23-
$(Expr(:meta,:inline))
24-
v = x.value
25-
z = zero(v)
26-
cmp = v < z
27-
r = ifelse(cmp, z, v)
28-
p = x.partials
29-
ForwardDiff.Dual(r, ForwardDiff.Partials(Base.Cartesian.@ntuple $N n -> ifelse(cmp, z, p[n])))
30-
end
22+
quote
23+
$(Expr(:meta,:inline))
24+
v = x.value
25+
z = zero(v)
26+
cmp = v < z
27+
r = ifelse(cmp, z, v)
28+
p = x.partials
29+
ForwardDiff.Dual{T}(r, ForwardDiff.Partials(Base.Cartesian.@ntuple $N n -> ifelse(cmp, z, p[n])))
30+
end
3131
end
3232
@generated function init_dual(v::Tuple{Vararg{AbstractSIMD,A}}) where {A}
33-
res = Expr(:tuple)
34-
q = Expr(:block, Expr(:meta,:inline))
35-
for a 1:A
36-
v_a = Symbol(:v_,a)
37-
push!(q.args, Expr(:(=), v_a, Expr(:ref, :v, a)))
38-
partials = Expr(:tuple)
39-
for i 1:A
40-
push!(partials.args, Expr(:call, i == a ? :one : :zero, v_a))
41-
end
42-
push!(res.args, :(ForwardDiff.Dual($v_a, ForwardDiff.Partials($partials))))
33+
res = Expr(:tuple)
34+
q = Expr(:block, Expr(:meta,:inline))
35+
for a 1:A
36+
v_a = Symbol(:v_,a)
37+
push!(q.args, Expr(:(=), v_a, Expr(:ref, :v, a)))
38+
partials = Expr(:tuple)
39+
for i 1:A
40+
push!(partials.args, Expr(:call, i == a ? :one : :zero, v_a))
4341
end
44-
push!(q.args, res)
45-
q
42+
push!(res.args, :(ForwardDiff.Dual($v_a, ForwardDiff.Partials($partials))))
43+
end
44+
push!(q.args, res)
45+
q
4646
end
4747
@generated function dual_store!(∂p::Tuple{Vararg{AbstractStridedPointer,A}}, p::AbstractStridedPointer, ∂v, im::Vararg{Any,N}) where {A,N}
48-
quote
49-
$(Expr(:meta,:inline))
50-
v = ∂v.value
51-
= ∂v.partials
52-
Base.Cartesian.@nextract $N im im
53-
Base.Cartesian.@ncall $N VectorizationBase.vnoaliasstore! p v im # store
54-
Base.Cartesian.@nexprs $A a -> begin # for each of `A` partials
55-
∂p_a = ∂p[a]
56-
∂_a = ∂[a]
57-
Base.Cartesian.@ncall $N VectorizationBase.vnoaliasstore! ∂p_a ∂_a im # store
58-
end
59-
nothing
48+
quote
49+
$(Expr(:meta,:inline))
50+
v = ∂v.value
51+
= ∂v.partials
52+
Base.Cartesian.@nextract $N im im
53+
Base.Cartesian.@ncall $N VectorizationBase.vnoaliasstore! p v im # store
54+
Base.Cartesian.@nexprs $A a -> begin # for each of `A` partials
55+
∂p_a = ∂p[a]
56+
∂_a = ∂[a]
57+
Base.Cartesian.@ncall $N VectorizationBase.vnoaliasstore! ∂p_a ∂_a im # store
6058
end
59+
nothing
60+
end
6161
end
6262

6363

0 commit comments

Comments
 (0)