Skip to content

Commit 82c2318

Browse files
committed
Test leakyrelu
1 parent 254546f commit 82c2318

File tree

11 files changed

+105
-41
lines changed

11 files changed

+105
-41
lines changed

benchmark/looptests.jl

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -76,12 +76,12 @@ function jgemm!(𝐂, 𝐀ᵀ::Adjoint, 𝐁ᵀ::Adjoint)
7676
end
7777
end
7878
gemmavx!(𝐂, 𝐀, 𝐁) = @turbo for m indices((𝐀, 𝐂), 1), n indices((𝐁, 𝐂), 2)
79-
𝐂ₘₙ = zero(eltype(𝐂))
80-
for k indices((𝐀, 𝐁), (2, 1))
81-
𝐂ₘₙ += 𝐀[m, k] * 𝐁[k, n]
82-
end
83-
𝐂[m, n] = 𝐂ₘₙ
79+
𝐂ₘₙ = zero(eltype(𝐂))
80+
for k indices((𝐀, 𝐁), (2, 1))
81+
𝐂ₘₙ += 𝐀[m, k] * 𝐁[k, n]
8482
end
83+
𝐂[m, n] = 𝐂ₘₙ
84+
end
8585
function gemmavx!(
8686
Cc::AbstractMatrix{Complex{T}},
8787
Ac::AbstractMatrix{Complex{T}},
@@ -102,12 +102,12 @@ function gemmavx!(
102102
end
103103
end
104104
gemmavxt!(𝐂, 𝐀, 𝐁) = @tturbo for m indices((𝐀, 𝐂), 1), n indices((𝐁, 𝐂), 2)
105-
𝐂ₘₙ = zero(eltype(𝐂))
106-
for k indices((𝐀, 𝐁), (2, 1))
107-
𝐂ₘₙ += 𝐀[m, k] * 𝐁[k, n]
108-
end
109-
𝐂[m, n] = 𝐂ₘₙ
105+
𝐂ₘₙ = zero(eltype(𝐂))
106+
for k indices((𝐀, 𝐁), (2, 1))
107+
𝐂ₘₙ += 𝐀[m, k] * 𝐁[k, n]
110108
end
109+
𝐂[m, n] = 𝐂ₘₙ
110+
end
111111
function gemmavxt!(
112112
Cc::AbstractMatrix{Complex{T}},
113113
Ac::AbstractMatrix{Complex{T}},
@@ -204,11 +204,11 @@ function jdot3avx(x, A, y)
204204
s
205205
end
206206
jvexp!(b, a) = @inbounds for i eachindex(a)
207-
b[i] = exp(a[i])
208-
end
207+
b[i] = exp(a[i])
208+
end
209209
jvexpavx!(b, a) = @turbo for i eachindex(a)
210-
b[i] = exp(a[i])
211-
end
210+
b[i] = exp(a[i])
211+
end
212212
function jsvexp(a)
213213
s = zero(eltype(a))
214214
@inbounds for i eachindex(a)
@@ -242,12 +242,12 @@ function jgemv!(𝐲, 𝐀ᵀ::Adjoint, 𝐱)
242242
end
243243
end
244244
jgemvavx!(𝐲, 𝐀, 𝐱) = @turbo for i eachindex(𝐲)
245-
𝐲ᵢ = zero(eltype(𝐲))
246-
for j eachindex(𝐱)
247-
𝐲ᵢ += 𝐀[i, j] * 𝐱[j]
248-
end
249-
𝐲[i] = 𝐲ᵢ
245+
𝐲ᵢ = zero(eltype(𝐲))
246+
for j eachindex(𝐱)
247+
𝐲ᵢ += 𝐀[i, j] * 𝐱[j]
250248
end
249+
𝐲[i] = 𝐲ᵢ
250+
end
251251
function jvar!(𝐬², 𝐀, x̄)
252252
@.= zero(eltype(𝐬²))
253253
@inbounds @fastmath for i 1:size(𝐀, 2)
@@ -258,14 +258,14 @@ function jvar!(𝐬², 𝐀, x̄)
258258
end
259259
end
260260
jvaravx!(𝐬², 𝐀, x̄) = @turbo for j eachindex(𝐬²)
261-
𝐬²ⱼ = zero(eltype(𝐬²))
262-
x̄ⱼ = x̄[j]
263-
for i 1:size(𝐀, 2)
264-
δ = 𝐀[j, i] - x̄ⱼ
265-
𝐬²ⱼ += δ * δ
266-
end
267-
𝐬²[j] = 𝐬²ⱼ
261+
𝐬²ⱼ = zero(eltype(𝐬²))
262+
x̄ⱼ = x̄[j]
263+
for i 1:size(𝐀, 2)
264+
δ = 𝐀[j, i] - x̄ⱼ
265+
𝐬²ⱼ += δ * δ
268266
end
267+
𝐬²[j] = 𝐬²ⱼ
268+
end
269269
japlucBc!(D, a, B, c) = @. D = a + B * c';
270270
japlucBcavx!(D, a, B, c) = @turbo @. D = a + B * c';
271271

benchmark/plotbenchmarks.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ else
2929
# const COLOR_MAP = Dict{String,RGB{Float64}}()
3030
# const COLOR_MAP = Dict{String,RGB{Colors.N0f8}}()
3131
const COLOR_MAP64 = Dict{String,RGB{Float64}}()
32-
getcolor(s::String) = get!(COLOR_MAP64, s) do
32+
getcolor(s::String) =
33+
get!(COLOR_MAP64, s) do
3334
COLORS[length(COLOR_MAP64)+1]
3435
end
3536
replace_and(str) = replace(str, '&' => "with")

ext/ForwardDiffExt.jl

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -197,13 +197,28 @@ end
197197
ForwardDiff.Dual{$TAG}(z, ForwardDiff.Partials(p))
198198
end
199199
end
200-
@inline ifelse(m::AbstractMask, x::ForwardDiff.Dual, y::Number) = _ifelse(m, x, y)
201-
@inline ifelse(m::AbstractMask, x::ForwardDiff.Dual, y::ForwardDiff.Dual) = _ifelse(m, x, y)
202-
@inline ifelse(m::AbstractMask, y::Number, x::ForwardDiff.Dual) = _ifelse(m, y, x)
200+
@inline ifelse(m::AbstractMask, x::ForwardDiff.Dual, y::Number) =
201+
_ifelse(m, x, y)
202+
@inline ifelse(m::AbstractMask, x::ForwardDiff.Dual, y::ForwardDiff.Dual) =
203+
_ifelse(m, x, y)
204+
@inline ifelse(m::AbstractMask, y::Number, x::ForwardDiff.Dual) =
205+
_ifelse(m, y, x)
203206

204-
@inline ifelse(m::VecUnroll{<:Any,<:Any,Bit,<:AbstractMask}, x::ForwardDiff.Dual, y::Number) = _ifelse(m, x, y)
205-
@inline ifelse(m::VecUnroll{<:Any,<:Any,Bit,<:AbstractMask}, x::ForwardDiff.Dual, y::ForwardDiff.Dual) = _ifelse(m, x, y)
206-
@inline ifelse(m::VecUnroll{<:Any,<:Any,Bit,<:AbstractMask}, y::Number, x::ForwardDiff.Dual) = _ifelse(m, y, x)
207+
@inline ifelse(
208+
m::VecUnroll{<:Any,<:Any,Bit,<:AbstractMask},
209+
x::ForwardDiff.Dual,
210+
y::Number
211+
) = _ifelse(m, x, y)
212+
@inline ifelse(
213+
m::VecUnroll{<:Any,<:Any,Bit,<:AbstractMask},
214+
x::ForwardDiff.Dual,
215+
y::ForwardDiff.Dual
216+
) = _ifelse(m, x, y)
217+
@inline ifelse(
218+
m::VecUnroll{<:Any,<:Any,Bit,<:AbstractMask},
219+
y::Number,
220+
x::ForwardDiff.Dual
221+
) = _ifelse(m, y, x)
207222

208223
@inline function SLEEFPirates.softplus(x::ForwardDiff.Dual{TAG}) where {TAG}
209224
val = ForwardDiff.value(x)

src/LoopVectorization.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,8 @@ using VectorizationBase:
108108
contract_or,
109109
collapse_or,
110110
max_mask,
111-
maybestaticsize,zero_mask
111+
maybestaticsize,
112+
zero_mask
112113

113114
using HostCPUFeatures:
114115
pick_vector_width,

src/codegen/split_loops.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,8 @@ function add_operation!(
7676
opnew
7777
end
7878

79-
append_if_included!(vnew, vold, included) = for (i, v) vold
79+
append_if_included!(vnew, vold, included) =
80+
for (i, v) vold
8081
id = included[i]
8182
iszero(id) || push!(vnew, (id, v))
8283
end

src/modeling/costs.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ struct Instruction
1313
end
1414
# lower(instr::Instruction) = Expr(:(.), instr.mod, QuoteNode(instr.instr))
1515
# Base.convert(::Type{Expr}, instr::Instruction) = Expr(:(.), instr.mod, QuoteNode(instr.instr))
16-
callexpr(instr::Instruction) = if instr.mod === :LoopVectorization
16+
callexpr(instr::Instruction) =
17+
if instr.mod === :LoopVectorization
1718
Expr(:call, lv(instr.instr))
1819
else#if instr.mod === :Main
1920
Expr(:call, instr.instr)
@@ -563,7 +564,8 @@ function reduction_to_single_vector(x::Float64)
563564
throw("Reduction not found.")
564565
end
565566
end
566-
reduce_to_onevecunroll(x::Float64) = if x == ADDITIVE_IN_REDUCTIONS
567+
reduce_to_onevecunroll(x::Float64) =
568+
if x == ADDITIVE_IN_REDUCTIONS
567569
:+
568570
elseif x == MULTIPLICATIVE_IN_REDUCTIONS
569571
:*
@@ -578,7 +580,8 @@ reduce_to_onevecunroll(x::Float64) = if x == ADDITIVE_IN_REDUCTIONS
578580
else
579581
throw("Reduction not found.")
580582
end
581-
reduce_number_of_vectors(x::Float64) = if x == ADDITIVE_IN_REDUCTIONS
583+
reduce_number_of_vectors(x::Float64) =
584+
if x == ADDITIVE_IN_REDUCTIONS
582585
:contract_add
583586
elseif x == MULTIPLICATIVE_IN_REDUCTIONS
584587
:contract_mul
@@ -593,7 +596,8 @@ reduce_number_of_vectors(x::Float64) = if x == ADDITIVE_IN_REDUCTIONS
593596
else
594597
throw("Reduction not found.")
595598
end
596-
reduction_to_scalar(x::Float64) = if x == ADDITIVE_IN_REDUCTIONS
599+
reduction_to_scalar(x::Float64) =
600+
if x == ADDITIVE_IN_REDUCTIONS
597601
:vsum
598602
elseif x == MULTIPLICATIVE_IN_REDUCTIONS
599603
:vprod

src/predicates.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ isscopedname(:(Base.Checked.checked_add), (:Base, :Checked), :checked_add)
1111
function isscopedname(ex, modpath, name::Symbol)
1212
isexpr(ex, :(.), 2) &&
1313
(a = ex.args[2]; isa(a, QuoteNode) && a.value === name) &&
14-
hasscope(ex.args[1], modpath)
14+
hasscope(ex.args[1], modpath)
1515
end
1616
hasscope(modex, mod::Symbol) = modex === mod
1717
hasscope(modex, mod::Tuple{Symbol}) = hasscope(modex, mod[1])

src/simdfunctionals/mapreduce.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,8 @@ end
115115
Vectorized version of `sum`. Providing a function as the first argument
116116
will apply the function to each element of `A` before summing.
117117
"""
118-
@inline vsum(f::F, A::AbstractArray{T}) where {F,T<:NativeTypes} = vmapreduce(f, +, A)
118+
@inline vsum(f::F, A::AbstractArray{T}) where {F,T<:NativeTypes} =
119+
vmapreduce(f, +, A)
119120
@inline vsum(A::AbstractArray{T}) where {T<:NativeTypes} = vsum(identity, A)
120121

121122
length_one_axis(::Base.OneTo) = Base.OneTo(1)

test/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
[deps]
22
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
33
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
4+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
45
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
56
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
7+
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
68
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
79
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
810
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -12,4 +14,5 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1214
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
1315
StrideArraysCore = "7792a7ef-975c-4747-a70f-980b88e8d1da"
1416
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
17+
VectorizationBase = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"
1518
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

test/forwarddiffext.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
2+
using NNlib, LoopVectorization, VectorizationBase, ForwardDiff, Test
3+
randnvec() = Vec(ntuple(_ -> randn(), pick_vector_width(Float64))...)
4+
5+
tovec(x::Vec{W,T}) where {W,T} = T[Tuple(x)...]
6+
tovec(x::VecUnroll) = reduce(vcat, map(tovec, VectorizationBase.data(x)))
7+
function tovec(x::ForwardDiff.Dual{T,V,N}) where {T,V,N}
8+
v = tovec(ForwardDiff.value(x))
9+
dv = map(tovec, Tuple(ForwardDiff.partials(x)))
10+
D = ForwardDiff.Dual{T,eltype(v),N}
11+
ret = Vector{D}(undef, length(v))
12+
for i in eachindex(v)
13+
ret[i] = ForwardDiff.Dual(v[i], map(Base.Fix2(Base.getindex, i), dv)...)
14+
end
15+
return ret
16+
end
17+
18+
19+
vx0 = randnvec()
20+
vx1 = randnvec()
21+
vx2 = randnvec()
22+
vx3 = randnvec()
23+
vx4 = randnvec()
24+
vx5 = randnvec()
25+
26+
vd0 = ForwardDiff.Dual(vx0, vx1, vx2, vx3, vx4, vx5)
27+
28+
vu0 = VecUnroll((vx0, vx1))
29+
vu1 = VecUnroll((vx2, vx3))
30+
vu2 = VecUnroll((vx4, vx5))
31+
32+
vud = ForwardDiff.Dual(vu0, vu1, vu2)
33+
34+
@test reinterpret(Float64, tovec(NNlib.leakyrelu(vd0)))
35+
reinterpret(Float64, NNlib.leakyrelu.(tovec(vd0)))
36+
@test reinterpret(Float64, tovec(NNlib.leakyrelu(vud)))
37+
reinterpret(Float64, NNlib.leakyrelu.(tovec(vud)))

0 commit comments

Comments
 (0)