Skip to content

Commit 0db912b

Browse files
committed
Couple non-AVX2 fixes/tweaks
1 parent 5e44c14 commit 0db912b

File tree

3 files changed

+16
-8
lines changed

3 files changed

+16
-8
lines changed

src/modeling/determinestrategy.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11

2+
23
# function indexappearences(op::Operation, s::Symbol)
34
# s ∉ loopdependencies(op) && return 0
45
# appearences = 0
@@ -95,14 +96,15 @@ function cost(ls::LoopSet, op::Operation, (u₁,u₂)::Tuple{Symbol,Symbol}, vlo
9596
shifter = 2
9697
offset = 0.5reg_size(ls) / cache_lnsze(ls)
9798
end
98-
if !rejectcurly(op) && (((contigind === CONSTANTZEROINDEX) && ((length(indices) > 1) && (indices[2] === u₁) || (indices[2] === u₂))) ||
99-
((u₁ === contigind) | (u₂ === contigind)))
99+
if shifter > 1 &&
100+
(!rejectcurly(op) && (((contigind === CONSTANTZEROINDEX) && ((length(indices) > 1) && (indices[2] === u₁) || (indices[2] === u₂))) ||
101+
((u₁ === contigind) | (u₂ === contigind))))
100102

101103
shifter -= 1
102104
offset = 0.5reg_size(ls) / cache_lnsze(ls)
103105
end
104106
r = 1 << shifter
105-
srt *= r + offset
107+
srt = srt*r + offset
106108
sl *= r
107109
elseif isload(op) & (length(loopdependencies(op)) > 1)# vmov(a/u)pd
108110
# penalize vectorized loads with more than 1 loopdep

src/simdfunctionals/filter.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,15 @@ function vfilter!(f::F, x::Vector{T}, y::AbstractArray{T}) where {F,T <: NativeT
77
j = 0
88
st = VectorizationBase.static_sizeof(T)
99
zero_index = MM(W, Static(0), st)
10+
incr = W * VectorizationBase.static_sizeof(T)
1011
GC.@preserve x y begin
1112
# ptr_x = llvmptr(x); ptr_y = llvmptr(y)
1213
ptr_x = pointer(x); ptr_y = pointer(y)
1314
for _ ∈ 1:Nrep
1415
vy = VectorizationBase.__vload(ptr_y, zero_index, False(), register_size())
1516
mask = f(vy)
1617
VectorizationBase.compressstore!(gep(ptr_x, VectorizationBase.lazymul(st, j)), vy, mask)
17-
ptr_y = gep(ptr_y, register_size())
18+
ptr_y = gep(ptr_y, incr)
1819
j = vadd_fast(j, count_ones(mask))
1920
end
2021
rem_mask = VectorizationBase.mask(T, Nrem)

test/ifelsemasks.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -561,10 +561,15 @@ T = Float32
561561
a = rand(-10:10, 43);
562562
bit = a .> 0.5; bool = copyto!(Vector{Bool}(undef, length(bit)), bit);
563563
t = Bernoulli_logit(bit, a);
564-
@test isapprox(t, Bernoulli_logitavx(bit, a), atol = Int === Int32 ? 0.1 : 0)
565-
@test isapprox(t, Bernoulli_logit_avx(bit, a), atol = Int === Int32 ? 0.1 : 0)
566-
@test isapprox(t, Bernoulli_logitavx(bool, a), atol = Int === Int32 ? 0.1 : 0)
567-
@test isapprox(t, Bernoulli_logit_avx(bool, a), atol = Int === Int32 ? 0.1 : 0)
564+
@test isapprox(t, Bernoulli_logitavx(bit, a), atol = ifelse(Int === Int32, 0.1, 0.0))
565+
if VectorizationBase.pick_vector_width(eltype(a)) ≥ 4
566+
# @_avx isn't really expected to work with bits if you don't have AVX512
567+
# but it happens to work with AVX2 for this anyway, so may as well keep testing.
568+
# am ruling out non-avx2 with the `VectorizationBase.pick_vector_width(eltype(a)) ≥ 4` check
569+
@test isapprox(t, Bernoulli_logit_avx(bit, a), atol = ifelse(Int === Int32, 0.1, 0.0))
570+
end
571+
@test isapprox(t, Bernoulli_logitavx(bool, a), atol = ifelse(Int === Int32, 0.1, 0.0))
572+
@test isapprox(t, Bernoulli_logit_avx(bool, a), atol = ifelse(Int === Int32, 0.1, 0.0))
568573
a = rand(43);
569574
bit = a .> 0.5; bool = copyto!(Vector{Bool}(undef, length(bit)), bit);
570575
t = Bernoulli_logit(bit, a);

0 commit comments

Comments
 (0)