Skip to content

Commit 73fb725

Browse files
authored
Fix and test leakyrelu (#505)
* ignore oftype
* broadcast fixes
* delete problematic line that accidentally wasn't removed
* no print
* fix order
* ifelse unroll
* Test leakyrelu
1 parent a21d6f8 commit 73fb725

File tree

13 files changed

+115
-43
lines changed

13 files changed

+115
-43
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LoopVectorization"
22
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
33
authors = ["Chris Elrod <[email protected]>"]
4-
version = "0.12.163"
4+
version = "0.12.164"
55

66

77
[deps]

benchmark/looptests.jl

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -76,12 +76,12 @@ function jgemm!(𝐂, 𝐀ᵀ::Adjoint, 𝐁ᵀ::Adjoint)
7676
end
7777
end
7878
gemmavx!(𝐂, 𝐀, 𝐁) = @turbo for m ∈ indices((𝐀, 𝐂), 1), n ∈ indices((𝐁, 𝐂), 2)
79-
𝐂ₘₙ = zero(eltype(𝐂))
80-
for k ∈ indices((𝐀, 𝐁), (2, 1))
81-
𝐂ₘₙ += 𝐀[m, k] * 𝐁[k, n]
82-
end
83-
𝐂[m, n] = 𝐂ₘₙ
79+
𝐂ₘₙ = zero(eltype(𝐂))
80+
for k ∈ indices((𝐀, 𝐁), (2, 1))
81+
𝐂ₘₙ += 𝐀[m, k] * 𝐁[k, n]
8482
end
83+
𝐂[m, n] = 𝐂ₘₙ
84+
end
8585
function gemmavx!(
8686
Cc::AbstractMatrix{Complex{T}},
8787
Ac::AbstractMatrix{Complex{T}},
@@ -102,12 +102,12 @@ function gemmavx!(
102102
end
103103
end
104104
gemmavxt!(𝐂, 𝐀, 𝐁) = @tturbo for m ∈ indices((𝐀, 𝐂), 1), n ∈ indices((𝐁, 𝐂), 2)
105-
𝐂ₘₙ = zero(eltype(𝐂))
106-
for k ∈ indices((𝐀, 𝐁), (2, 1))
107-
𝐂ₘₙ += 𝐀[m, k] * 𝐁[k, n]
108-
end
109-
𝐂[m, n] = 𝐂ₘₙ
105+
𝐂ₘₙ = zero(eltype(𝐂))
106+
for k ∈ indices((𝐀, 𝐁), (2, 1))
107+
𝐂ₘₙ += 𝐀[m, k] * 𝐁[k, n]
110108
end
109+
𝐂[m, n] = 𝐂ₘₙ
110+
end
111111
function gemmavxt!(
112112
Cc::AbstractMatrix{Complex{T}},
113113
Ac::AbstractMatrix{Complex{T}},
@@ -204,11 +204,11 @@ function jdot3avx(x, A, y)
204204
s
205205
end
206206
jvexp!(b, a) = @inbounds for i ∈ eachindex(a)
207-
b[i] = exp(a[i])
208-
end
207+
b[i] = exp(a[i])
208+
end
209209
jvexpavx!(b, a) = @turbo for i ∈ eachindex(a)
210-
b[i] = exp(a[i])
211-
end
210+
b[i] = exp(a[i])
211+
end
212212
function jsvexp(a)
213213
s = zero(eltype(a))
214214
@inbounds for i ∈ eachindex(a)
@@ -242,12 +242,12 @@ function jgemv!(𝐲, 𝐀ᵀ::Adjoint, 𝐱)
242242
end
243243
end
244244
jgemvavx!(𝐲, 𝐀, 𝐱) = @turbo for i ∈ eachindex(𝐲)
245-
𝐲ᵢ = zero(eltype(𝐲))
246-
for j ∈ eachindex(𝐱)
247-
𝐲ᵢ += 𝐀[i, j] * 𝐱[j]
248-
end
249-
𝐲[i] = 𝐲ᵢ
245+
𝐲ᵢ = zero(eltype(𝐲))
246+
for j ∈ eachindex(𝐱)
247+
𝐲ᵢ += 𝐀[i, j] * 𝐱[j]
250248
end
249+
𝐲[i] = 𝐲ᵢ
250+
end
251251
function jvar!(𝐬², 𝐀, x̄)
252252
@. 𝐬² = zero(eltype(𝐬²))
253253
@inbounds @fastmath for i ∈ 1:size(𝐀, 2)
@@ -258,14 +258,14 @@ function jvar!(𝐬², 𝐀, x̄)
258258
end
259259
end
260260
jvaravx!(𝐬², 𝐀, x̄) = @turbo for j ∈ eachindex(𝐬²)
261-
𝐬²ⱼ = zero(eltype(𝐬²))
262-
x̄ⱼ = x̄[j]
263-
for i ∈ 1:size(𝐀, 2)
264-
δ = 𝐀[j, i] - x̄ⱼ
265-
𝐬²ⱼ += δ * δ
266-
end
267-
𝐬²[j] = 𝐬²ⱼ
261+
𝐬²ⱼ = zero(eltype(𝐬²))
262+
x̄ⱼ = x̄[j]
263+
for i ∈ 1:size(𝐀, 2)
264+
δ = 𝐀[j, i] - x̄ⱼ
265+
𝐬²ⱼ += δ * δ
268266
end
267+
𝐬²[j] = 𝐬²ⱼ
268+
end
269269
japlucBc!(D, a, B, c) = @. D = a + B * c';
270270
japlucBcavx!(D, a, B, c) = @turbo @. D = a + B * c';
271271

benchmark/plotbenchmarks.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ else
2929
# const COLOR_MAP = Dict{String,RGB{Float64}}()
3030
# const COLOR_MAP = Dict{String,RGB{Colors.N0f8}}()
3131
const COLOR_MAP64 = Dict{String,RGB{Float64}}()
32-
getcolor(s::String) = get!(COLOR_MAP64, s) do
32+
getcolor(s::String) =
33+
get!(COLOR_MAP64, s) do
3334
COLORS[length(COLOR_MAP64)+1]
3435
end
3536
replace_and(str) = replace(str, '&' => "with")

ext/ForwardDiffExt.jl

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,8 @@ end
157157
end
158158
end
159159

160-
@generated function ifelse(
161-
m::AbstractMask,
160+
@generated function _ifelse(
161+
m::Union{AbstractMask,VecUnroll{<:Any,<:Any,Bit,<:AbstractMask}},
162162
x::ForwardDiff.Dual{TAG,V,P},
163163
y::ForwardDiff.Dual{TAG,V,P}
164164
) where {TAG,V,P}
@@ -171,8 +171,8 @@ end
171171
ForwardDiff.Dual{$TAG}(z, ForwardDiff.Partials(p))
172172
end
173173
end
174-
@generated function ifelse(
175-
m::AbstractMask,
174+
@generated function _ifelse(
175+
m::Union{AbstractMask,VecUnroll{<:Any,<:Any,Bit,<:AbstractMask}},
176176
x::Number,
177177
y::ForwardDiff.Dual{TAG,V,P}
178178
) where {TAG,V,P}
@@ -184,8 +184,8 @@ end
184184
ForwardDiff.Dual{$TAG}(z, ForwardDiff.Partials(p))
185185
end
186186
end
187-
@generated function ifelse(
188-
m::AbstractMask,
187+
@generated function _ifelse(
188+
m::Union{AbstractMask,VecUnroll{<:Any,<:Any,Bit,<:AbstractMask}},
189189
x::ForwardDiff.Dual{TAG,V,P},
190190
y::Number
191191
) where {TAG,V,P}
@@ -197,6 +197,29 @@ end
197197
ForwardDiff.Dual{$TAG}(z, ForwardDiff.Partials(p))
198198
end
199199
end
200+
@inline ifelse(m::AbstractMask, x::ForwardDiff.Dual, y::Number) =
201+
_ifelse(m, x, y)
202+
@inline ifelse(m::AbstractMask, x::ForwardDiff.Dual, y::ForwardDiff.Dual) =
203+
_ifelse(m, x, y)
204+
@inline ifelse(m::AbstractMask, y::Number, x::ForwardDiff.Dual) =
205+
_ifelse(m, y, x)
206+
207+
@inline ifelse(
208+
m::VecUnroll{<:Any,<:Any,Bit,<:AbstractMask},
209+
x::ForwardDiff.Dual,
210+
y::Number
211+
) = _ifelse(m, x, y)
212+
@inline ifelse(
213+
m::VecUnroll{<:Any,<:Any,Bit,<:AbstractMask},
214+
x::ForwardDiff.Dual,
215+
y::ForwardDiff.Dual
216+
) = _ifelse(m, x, y)
217+
@inline ifelse(
218+
m::VecUnroll{<:Any,<:Any,Bit,<:AbstractMask},
219+
y::Number,
220+
x::ForwardDiff.Dual
221+
) = _ifelse(m, y, x)
222+
200223
@inline function SLEEFPirates.softplus(x::ForwardDiff.Dual{TAG}) where {TAG}
201224
val = ForwardDiff.value(x)
202225
expx = exp(val)

src/LoopVectorization.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,8 @@ using VectorizationBase:
108108
contract_or,
109109
collapse_or,
110110
max_mask,
111-
maybestaticsize,zero_mask
111+
maybestaticsize,
112+
zero_mask
112113

113114
using HostCPUFeatures:
114115
pick_vector_width,

src/codegen/split_loops.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,8 @@ function add_operation!(
7676
opnew
7777
end
7878

79-
append_if_included!(vnew, vold, included) = for (i, v) ∈ vold
79+
append_if_included!(vnew, vold, included) =
80+
for (i, v) ∈ vold
8081
id = included[i]
8182
iszero(id) || push!(vnew, (id, v))
8283
end

src/modeling/costs.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ struct Instruction
1313
end
1414
# lower(instr::Instruction) = Expr(:(.), instr.mod, QuoteNode(instr.instr))
1515
# Base.convert(::Type{Expr}, instr::Instruction) = Expr(:(.), instr.mod, QuoteNode(instr.instr))
16-
callexpr(instr::Instruction) = if instr.mod === :LoopVectorization
16+
callexpr(instr::Instruction) =
17+
if instr.mod === :LoopVectorization
1718
Expr(:call, lv(instr.instr))
1819
else#if instr.mod === :Main
1920
Expr(:call, instr.instr)
@@ -563,7 +564,8 @@ function reduction_to_single_vector(x::Float64)
563564
throw("Reduction not found.")
564565
end
565566
end
566-
reduce_to_onevecunroll(x::Float64) = if x == ADDITIVE_IN_REDUCTIONS
567+
reduce_to_onevecunroll(x::Float64) =
568+
if x == ADDITIVE_IN_REDUCTIONS
567569
:+
568570
elseif x == MULTIPLICATIVE_IN_REDUCTIONS
569571
:*
@@ -578,7 +580,8 @@ reduce_to_onevecunroll(x::Float64) = if x == ADDITIVE_IN_REDUCTIONS
578580
else
579581
throw("Reduction not found.")
580582
end
581-
reduce_number_of_vectors(x::Float64) = if x == ADDITIVE_IN_REDUCTIONS
583+
reduce_number_of_vectors(x::Float64) =
584+
if x == ADDITIVE_IN_REDUCTIONS
582585
:contract_add
583586
elseif x == MULTIPLICATIVE_IN_REDUCTIONS
584587
:contract_mul
@@ -593,7 +596,8 @@ reduce_number_of_vectors(x::Float64) = if x == ADDITIVE_IN_REDUCTIONS
593596
else
594597
throw("Reduction not found.")
595598
end
596-
reduction_to_scalar(x::Float64) = if x == ADDITIVE_IN_REDUCTIONS
599+
reduction_to_scalar(x::Float64) =
600+
if x == ADDITIVE_IN_REDUCTIONS
597601
:vsum
598602
elseif x == MULTIPLICATIVE_IN_REDUCTIONS
599603
:vprod

src/predicates.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ isscopedname(:(Base.Checked.checked_add), (:Base, :Checked), :checked_add)
1111
function isscopedname(ex, modpath, name::Symbol)
1212
isexpr(ex, :(.), 2) &&
1313
(a = ex.args[2]; isa(a, QuoteNode) && a.value === name) &&
14-
hasscope(ex.args[1], modpath)
14+
hasscope(ex.args[1], modpath)
1515
end
1616
hasscope(modex, mod::Symbol) = modex === mod
1717
hasscope(modex, mod::Tuple{Symbol}) = hasscope(modex, mod[1])

src/reconstruct_loopset.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ Base.promote_rule(
2727
::Type{UpperBoundedInteger{N,T}},
2828
::Type{T}
2929
) where {N,T<:Base.BitInteger} = T
30-
Base.convert(::Type{T}, i::UpperBoundedInteger) where {T<:Number} =
30+
Base.convert(::Type{T}, i::UpperBoundedInteger) where {T<:Integer} =
3131
convert(T, i.i)
3232
Base.convert(
3333
::Type{UpperBoundedInteger{N,T}},

src/simdfunctionals/mapreduce.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,8 @@ end
115115
Vectorized version of `sum`. Providing a function as the first argument
116116
will apply the function to each element of `A` before summing.
117117
"""
118-
@inline vsum(f::F, A::AbstractArray{T}) where {F,T<:NativeTypes} = vmapreduce(f, +, A)
118+
@inline vsum(f::F, A::AbstractArray{T}) where {F,T<:NativeTypes} =
119+
vmapreduce(f, +, A)
119120
@inline vsum(A::AbstractArray{T}) where {T<:NativeTypes} = vsum(identity, A)
120121

121122
length_one_axis(::Base.OneTo) = Base.OneTo(1)

0 commit comments

Comments
 (0)