Skip to content

Commit d836395

Browse files
committed
Allow loop variables to be used in vectorized comparisons, and fix bug related to incorrectly eliminated stores.
1 parent a361f17 commit d836395

File tree

7 files changed

+62
-28
lines changed

7 files changed

+62
-28
lines changed

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LoopVectorization"
22
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
33
authors = ["Chris Elrod <[email protected]>"]
4-
version = "0.6.26"
4+
version = "0.6.27"
55

66
[deps]
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -13,10 +13,10 @@ VectorizationBase = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"
1313

1414
[compat]
1515
OffsetArrays = "1"
16-
SIMDPirates = "0.7.6"
16+
SIMDPirates = "0.7.7"
1717
SLEEFPirates = "0.4"
1818
UnPack = "0"
19-
VectorizationBase = "0.9.6"
19+
VectorizationBase = "0.10"
2020
julia = "1.1"
2121

2222
[extras]

benchmark/directcalljit.f90

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,52 @@
11
module jitmul
22

3-
include "/opt/intel/mkl/include/mkl_direct_call.fi"
4-
53
use ISO_C_BINDING
4+
use mkl_service
65
implicit none
76

8-
contains
7+
include "/opt/intel/mkl/include/mkl_direct_call.fi"
8+
! include "/opt/intel/mkl/include/mkl_service.fi"
9+
! include "/opt/intel/mkl/include/mkl.fi"
910

11+
contains
12+
subroutine set_num_threads(N) bind(C, name = "set_num_threads")
13+
integer(C_int32_t) :: N
14+
call mkl_set_num_threads(N)
15+
end subroutine set_num_threads
16+
1017
! subroutine dgemmjit(C,A,B,M,K,N,alpha,beta) bind(C, name = "dgemmjit")
18+
subroutine sgemmjit(C,A,B,M,K,N,At,Bt) bind(C, name = "sgemmjit")
19+
integer(C_int32_t), intent(in) :: M, K, N
20+
integer(C_int8_t), intent(in) :: At, Bt
21+
real(C_float), parameter :: alpha = 1.0e0, beta = 0.0e0
22+
! real(C_float), intent(in) :: alpha, beta
23+
real(C_float), dimension(M,K), intent(in) :: A
24+
real(C_float), dimension(K,N), intent(in) :: B
25+
real(C_float), dimension(M,N), intent(out) :: C
26+
character :: Atc, Btc
27+
! call mkl_set_threading_layer(MKL_THREADING_SEQUENTIAL)
28+
if (At == 1_C_int8_t) then
29+
Atc = 'T'
30+
else
31+
Atc = 'N'
32+
end if
33+
if (Bt == 1_C_int8_t) then
34+
Btc = 'T'
35+
else
36+
Btc = 'N'
37+
end if
38+
call sgemm(Atc, Btc, M, N, K, alpha, A, M, B, K, beta, C, M)
39+
end subroutine sgemmjit
1140
subroutine dgemmjit(C,A,B,M,K,N,At,Bt) bind(C, name = "dgemmjit")
1241
integer(C_int32_t), intent(in) :: M, K, N
1342
integer(C_int8_t), intent(in) :: At, Bt
14-
real(C_double), parameter :: alpha = 1.0D0, beta = 0.0D0
43+
real(C_double), parameter :: alpha = 1.0d0, beta = 0.0d0
1544
! real(C_double), intent(in) :: alpha, beta
1645
real(C_double), dimension(M,K), intent(in) :: A
1746
real(C_double), dimension(K,N), intent(in) :: B
1847
real(C_double), dimension(M,N), intent(out) :: C
1948
character :: Atc, Btc
49+
! call mkl_set_threading_layer(MKL_THREADING_SEQUENTIAL)
2050
if (At == 1_C_int8_t) then
2151
Atc = 'T'
2252
else

src/LoopVectorization.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using VectorizationBase: REGISTER_SIZE, REGISTER_COUNT, extract_data, num_vector
77
Static, StaticUnitRange, StaticLowerUnitRange, StaticUpperUnitRange, unwrap, maybestaticrange,
88
AbstractColumnMajorStridedPointer, AbstractRowMajorStridedPointer, AbstractSparseStridedPointer, AbstractStaticStridedPointer,
99
PackedStridedPointer, SparseStridedPointer, RowMajorStridedPointer, StaticStridedPointer, StaticStridedStruct,
10-
maybestaticfirst, maybestaticlast
10+
maybestaticfirst, maybestaticlast, scalar_less, scalar_greater
1111
using SIMDPirates: VECTOR_SYMBOLS, evadd, evsub, evmul, evfdiv, vrange, reduced_add, reduced_prod, reduce_to_add, reduce_to_prod,
1212
sizeequivalentfloat, sizeequivalentint, vadd!, vsub!, vmul!, vfdiv!, vfmadd!, vfnmadd!, vfmsub!, vfnmsub!,
1313
vfmadd231, vfmsub231, vfnmadd231, vfnmsub231, sizeequivalentfloat, sizeequivalentint, #prefetch,

src/add_stores.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,16 +35,18 @@ function add_store!(
3535
pvar = name(parent)
3636
id = length(ls.operations)
3737
# try to cse store, by replacing the previous one
38-
ref = mpref.mref.ref
38+
mref = mpref.mref
3939
add_pvar = true
4040
for opp operations(ls)
4141
isstore(opp) || continue
42-
if ref == opp.ref.ref
42+
if mref == opp.ref
4343
id = opp.identifier
44+
add_pvar = false
4445
break
4546
end
46-
add_pvar &= (name(first(parents(opp))) != pvar)
47+
# add_pvar &= (name(first(parents(opp))) != pvar)
4748
end
49+
# @show add_pvar
4850
pushfirst!(vparents, parent)
4951
update_deps!(ldref, reduceddeps, parent)
5052
op = Operation( id, name(mpref), elementbytes, :setindex!, memstore, mpref )

src/graphs.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,17 +66,17 @@ function vec_looprange(loop::Loop, W::Symbol, UF::Int, mangledname::Symbol)
6666
Expr(:call, lv(:valsub), W, 2)
6767
end
6868
if loop.stopexact # split for type stability
69-
Expr(:call, :<, mangledname, Expr(:call, :-, loop.stophint, incr))
69+
Expr(:call, lv(:scalar_less), mangledname, Expr(:call, :-, loop.stophint, incr))
7070
else
71-
Expr(:call, :<, mangledname, Expr(:call, :-, loop.stopsym, incr))
71+
Expr(:call, lv(:scalar_less), mangledname, Expr(:call, :-, loop.stopsym, incr))
7272
end
7373
end
7474
function looprange(loop::Loop, incr::Int, mangledname::Symbol)
7575
incr = 2 - incr
7676
if iszero(incr)
77-
Expr(:call, :<, mangledname, loop.stopexact ? loop.stophint : loop.stopsym)
77+
Expr(:call, lv(:scalar_less), mangledname, loop.stopexact ? loop.stophint : loop.stopsym)
7878
else
79-
Expr(:call, :<, mangledname, loop.stopexact ? loop.stophint + incr : Expr(:call, :+, loop.stopsym, incr))
79+
Expr(:call, lv(:scalar_less), mangledname, loop.stopexact ? loop.stophint + incr : Expr(:call, :+, loop.stopsym, incr))
8080
end
8181
end
8282
function terminatecondition(

src/lower_store.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using VectorizationBase: vnoaliasstore!
2-
# const STOREOP = :vnoaliasstore!
2+
33

44
@inline vstoreadditivereduce!(args...) = vnoaliasstore!(args...)
55
@inline vstoremultiplicativevereduce!(args...) = vnoaliasstore!(args...)
@@ -21,16 +21,18 @@ function storeinstr(op::Operation)
2121
if instruction(opp).instr === :identity
2222
opp = first(parents(opp))
2323
end
24+
defaultstoreop = :vnoaliasstore!
25+
# defaultstoreop = :vstore!
2426
instr = if iszero(length(reduceddependencies(opp)))
25-
:vnoaliasstore!
27+
defaultstoreop
2628
else
2729
instr_class = reduction_instruction_class(instruction(opp))
2830
if instr_class === ADDITIVE_IN_REDUCTIONS
2931
:vstoreadditivereduce!
3032
elseif instr_class === MULTIPLICATIVE_IN_REDUCTIONS
3133
:vstoremultiplicativevereduce!
3234
else #FIXME
33-
:vnoaliasstore!
35+
defaultstoreop
3436
end
3537
end
3638
lv(instr)

src/lowering.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -191,11 +191,11 @@ function lower_unrolled_dynamic(ls::LoopSet, us::UnrollSpecification, n::Int, in
191191
else
192192
Expr(:call, :-, loop.stopsym, Expr(:call, lv(:valmul), ls.W, UFt))
193193
end
194-
Expr(:call, :>, loopsym, itercount)
194+
Expr(:call, lv(:scalar_greater), loopsym, itercount)
195195
elseif loop.stopexact
196-
Expr(:call, :>, loopsym, loop.stophint - UFt)
196+
Expr(:call, lv(:scalar_greater), loopsym, loop.stophint - UFt)
197197
else
198-
Expr(:call, :>, loopsym, Expr(:call, :-, loop.stopsym, UFt))
198+
Expr(:call, lv(:scalar_greater), loopsym, Expr(:call, :-, loop.stopsym, UFt))
199199
end
200200
ust = nisunrolled ? UnrollSpecification(us, UFt, T) : UnrollSpecification(us, U, UFt)
201201
remblocknew = Expr(:elseif, comparison, lower_block(ls, ust, n, remmask, UFt))
@@ -244,15 +244,15 @@ function add_upper_outer_reductions(ls::LoopSet, loopq::Expr, Ulow::Int, Uhigh::
244244
push!(ifq.args, loopq)
245245
reduce_range!(ifq, ls, Ulow, Uhigh)
246246
comparison = if isstaticloop(unrolledloop)
247-
Expr(:call, :<, length(unrolledloop), Expr(:call, lv(:valmul), ls.W, Uhigh))
247+
Expr(:call, lv(:scalar_less), length(unrolledloop), Expr(:call, lv(:valmul), ls.W, Uhigh))
248248
elseif unrolledloop.starthint == 1
249-
Expr(:call, :<, unrolledloop.stopsym, Expr(:call, lv(:valmul), ls.W, Uhigh))
249+
Expr(:call, lv(:scalar_less), unrolledloop.stopsym, Expr(:call, lv(:valmul), ls.W, Uhigh))
250250
elseif unrolledloop.startexact
251-
Expr(:call, :<, Expr(:call, :-, unrolledloop.stopsym, unrolledloop.starthint-1), Expr(:call, lv(:valmul), ls.W, Uhigh))
251+
Expr(:call, lv(:scalar_less), Expr(:call, :-, unrolledloop.stopsym, unrolledloop.starthint-1), Expr(:call, lv(:valmul), ls.W, Uhigh))
252252
elseif unrolledloop.stopexact
253-
Expr(:call, :<, Expr(:call, :-, unrolledloop.stophint+1, unrolledloop.sartsym), Expr(:call, lv(:valmul), ls.W, Uhigh))
253+
Expr(:call, lv(:scalar_less), Expr(:call, :-, unrolledloop.stophint+1, unrolledloop.sartsym), Expr(:call, lv(:valmul), ls.W, Uhigh))
254254
else# both are given by symbols
255-
Expr(:call, :<, Expr(:call, :-, unrolledloop.stopsym, Expr(:call,:-,unrolledloop.startsym)), Expr(:call, lv(:valmul), ls.W, Uhigh))
255+
Expr(:call, lv(:scalar_less), Expr(:call, :-, unrolledloop.stopsym, Expr(:call,:-,unrolledloop.startsym)), Expr(:call, lv(:valmul), ls.W, Uhigh))
256256
end
257257
ncomparison = Expr(:call, :!, comparison)
258258
Expr(:if, ncomparison, ifq)
@@ -307,9 +307,9 @@ function determine_width(ls::LoopSet, vectorized::Symbol)
307307
end
308308
function init_remblock(unrolledloop::Loop, unrolled::Symbol = unrolledloop.itersymbol)
309309
condition = if unrolledloop.stopexact
310-
Expr(:call, :(>), unrolled, unrolledloop.stophint)
310+
Expr(:call, lv(:scalar_greater), unrolled, unrolledloop.stophint)
311311
else
312-
Expr(:call, :(>), unrolled, unrolledloop.stopsym)
312+
Expr(:call, lv(:scalar_greater), unrolled, unrolledloop.stopsym)
313313
end
314314
Expr(:if, condition, nothing)
315315
end

0 commit comments

Comments
 (0)