diff --git a/Project.toml b/Project.toml index bc474cbb..8d2e94f0 100644 --- a/Project.toml +++ b/Project.toml @@ -9,7 +9,6 @@ CPUSummary = "2a0fbf3d-bb9c-48f3-b0a9-814d99fd7ab9" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" CloseOpenIntervals = "fb6a15b2-703c-40df-9091-08a04967cfa9" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" HostCPUFeatures = "3e5b6fbb-0976-4d2c-9146-d79de83f2fb0" IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" LayoutPointers = "10f19ff3-798f-405d-979b-55457f8fc047" @@ -28,8 +27,8 @@ VectorizationBase = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" [extensions] ForwardDiffExt = ["ChainRulesCore", "ForwardDiff"] @@ -57,4 +56,7 @@ StaticArrayInterface = "1" ThreadingUtilities = "0.5" UnPack = "1" VectorizationBase = "0.21.67" -julia = "1.6" +julia = "1.10" + +[extras] +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" diff --git a/ext/ForwardDiffExt.jl b/ext/ForwardDiffExt.jl index 26227f69..2da074eb 100644 --- a/ext/ForwardDiffExt.jl +++ b/ext/ForwardDiffExt.jl @@ -9,6 +9,7 @@ using LoopVectorization: AbstractSIMD, AbstractStridedPointer, relu, + leakyrelu, vmap, VectorizationBase, vmapt, @@ -140,6 +141,27 @@ end ) end end + +@generated function VectorizationBase.leakyrelu( + x::ForwardDiff.Dual{T,S,N}, + a = 0.01 +) where {T,S,N} + quote + $(Expr(:meta, :inline)) + v = x.value + z = zero(v) + + α = convert(typeof(v), a) + cmp = v < z + r = ifelse(cmp, α * v, v) + p = x.partials + ForwardDiff.Dual{T}( + r, + ForwardDiff.Partials(Base.Cartesian.@ntuple $N n -> ifelse(cmp, α * p[n], p[n])) + ) + end +end + @generated function VectorizationBase.relu( x::ForwardDiff.Dual{T,S,N} ) where {T,S,N} @@ -171,6 +193,7 @@ end ForwardDiff.Dual{$TAG}(z, ForwardDiff.Partials(p)) end end + @generated function _ifelse( m::Union{AbstractMask,VecUnroll{<:Any,<:Any,Bit,<:AbstractMask}}, x::Number, diff --git a/src/LoopVectorization.jl b/src/LoopVectorization.jl index 4af0f832..f3c9997a 100644 --- a/src/LoopVectorization.jl +++ b/src/LoopVectorization.jl @@ -89,6 +89,7 @@ using VectorizationBase: vsub_fast, vmul_fast, relu, + leakyrelu, stridedpointer, _vload, _vstore!, diff --git a/test/forwarddiffext.jl b/test/forwarddiffext.jl index b4b905c7..3f965ffd 100644 --- a/test/forwarddiffext.jl +++ b/test/forwarddiffext.jl @@ -16,21 +16,6 @@ function tovec(x::ForwardDiff.Dual{T,V,N}) where {T,V,N} return ret end -if LoopVectorization.ifelse !== Base.ifelse - @inline function NNlib.leakyrelu( - x::LoopVectorization.AbstractSIMD, - a = NNlib.oftf(x, NNlib.leakyrelu_a), - ) - LoopVectorization.ifelse(x > zero(x), float(x), NNlib.oftf(x, a * x)) # max(a*x, x) is 3x slower - end - @inline function NNlib.leakyrelu( - x::ForwardDiff.Dual{<:Any,<:LoopVectorization.AbstractSIMD}, - a = NNlib.oftf(x, NNlib.leakyrelu_a), - ) - LoopVectorization.ifelse(x > zero(x), float(x), NNlib.oftf(x, a * x)) # max(a*x, x) is 3x slower - end -end - vx0 = randnvec() vx1 = randnvec() vx2 = randnvec() @@ -46,7 +31,13 @@ vu2 = VecUnroll((vx4, vx5)) vud = ForwardDiff.Dual(vu0, vu1, vu2) -@test reinterpret(Float64, tovec(NNlib.leakyrelu(vd0))) ≈ - reinterpret(Float64, NNlib.leakyrelu.(tovec(vd0))) -@test reinterpret(Float64, tovec(NNlib.leakyrelu(vud))) ≈ - reinterpret(Float64, NNlib.leakyrelu.(tovec(vud))) + +@test reinterpret(Float64, tovec(VectorizationBase.relu(vd0))) ≈ + reinterpret(Float64, VectorizationBase.relu.(tovec(vd0))) +@test reinterpret(Float64, tovec(VectorizationBase.relu(vud))) ≈ + reinterpret(Float64, VectorizationBase.relu.(tovec(vud))) + +@test reinterpret(Float64, tovec(VectorizationBase.leakyrelu(vd0))) ≈ + reinterpret(Float64, VectorizationBase.leakyrelu.(tovec(vd0))) +@test reinterpret(Float64, tovec(VectorizationBase.leakyrelu(vud))) ≈ + reinterpret(Float64, VectorizationBase.leakyrelu.(tovec(vud)))