Skip to content

Commit 715ed8f

Browse files
committed
MKL does use 64-bit API, so fix benchmarks.
1 parent 272c856 commit 715ed8f

File tree

1 file changed

+34
-41
lines changed

1 file changed

+34
-41
lines changed

benchmark/loadsharedlibs.jl

Lines changed: 34 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -61,74 +61,67 @@ end
6161
# end
6262

6363

64+
randa(::Type{T}, dim...) where {T} = rand(T, dim...)
65+
# Draw signed-integer test data uniformly from the range -100:200, converted to `T`.
randa(::Type{T}, dim...) where {T <: Signed} = rand(T(-100):T(200), dim...)
# `Int8` cannot represent 200 (`Int8(200)` throws an `InexactError`), so the upper
# bound is capped at `typemax(Int8)` for that type only; all other types are unchanged.
randa(::Type{Int8}, dim...) = rand(Int8(-100):typemax(Int8), dim...)
66+
6467
using MKL_jll, OpenBLAS_jll
6568

6669
const libMKL = Libdl.dlopen(MKL_jll.libmkl_rt)
6770
const DGEMM_MKL = Libdl.dlsym(libMKL, :dgemm)
71+
const SGEMM_MKL = Libdl.dlsym(libMKL, :sgemm)
6872
const DGEMV_MKL = Libdl.dlsym(libMKL, :dgemv)
6973
const MKL_SET_NUM_THREADS = Libdl.dlsym(libMKL, :MKL_Set_Num_Threads)
7074

7175
const libOpenBLAS = Libdl.dlopen(OpenBLAS_jll.libopenblas)
7276
const DGEMM_OpenBLAS = Libdl.dlsym(libOpenBLAS, :dgemm_64_)
77+
const SGEMM_OpenBLAS = Libdl.dlsym(libOpenBLAS, :sgemm_64_)
7378
const DGEMV_OpenBLAS = Libdl.dlsym(libOpenBLAS, :dgemv_64_)
7479
const OPENBLAS_SET_NUM_THREADS = Libdl.dlsym(libOpenBLAS, :openblas_set_num_threads64_)
7580

7681
istransposed(x) = 'N'
7782
istransposed(x::Adjoint{<:Real}) = 'T'
7883
istransposed(x::Adjoint) = 'C'
7984
istransposed(x::Transpose) = 'T'
80-
function dgemmmkl!(C::AbstractMatrix{Float64}, A::AbstractMatrix{Float64}, B::AbstractMatrix{Float64})
81-
transA = istransposed(A)
82-
transB = istransposed(B)
83-
M, N = size(C); K = size(B, 1)
84-
M32 = M % Int32
85-
K32 = K % Int32
86-
N32 = N % Int32
87-
pA = parent(A); pB = parent(B)
88-
ldA = stride(pA, 2) % Int32
89-
ldB = stride(pB, 2) % Int32
90-
ldC = stride(C, 2) % Int32
91-
α = 1.0
92-
β = 0.0
93-
ccall(
94-
DGEMM_MKL, Cvoid,
95-
(Ref{UInt8}, Ref{UInt8}, Ref{Int32}, Ref{Int32}, Ref{Int32}, Ref{Float64}, Ref{Float64}, Ref{Int32}, Ref{Float64}, Ref{Int32}, Ref{Float64}, Ref{Float64}, Ref{Int32}),
96-
transA, transB, M32, N32, K32, α, pA, ldA, pB, ldB, β, C, ldC
97-
)
98-
end
99-
function dgemmopenblas!(C::AbstractMatrix{Float64}, A::AbstractMatrix{Float64}, B::AbstractMatrix{Float64})
100-
transA = istransposed(A)
101-
transB = istransposed(B)
102-
M, N = size(C); K = size(B, 1)
103-
pA = parent(A); pB = parent(B)
104-
ldA = stride(pA, 2)
105-
ldB = stride(pB, 2)
106-
ldC = stride(C, 2)
107-
α = 1.0
108-
β = 0.0
109-
ccall(
110-
DGEMM_OpenBLAS, Cvoid,
111-
(Ref{UInt8}, Ref{UInt8}, Ref{Int64}, Ref{Int64}, Ref{Int64}, Ref{Float64}, Ref{Float64}, Ref{Int64}, Ref{Float64}, Ref{Int64}, Ref{Float64}, Ref{Float64}, Ref{Int64}),
112-
transA, transB, M, N, K, α, pA, ldA, pB, ldB, β, C, ldC
113-
)
85+
for (lib,f) in [(:GEMM_MKL,:gemmmkl!), (:GEMM_OpenBLAS,:gemmopenblas!)]
86+
for (T,prefix) in [(Float32,:S),(Float64,:D)]
87+
fm = Symbol(prefix, lib)
88+
@eval begin
89+
function $f(C::AbstractMatrix{$T}, A::AbstractMatrix{$T}, B::AbstractMatrix{$T})
90+
transA = istransposed(A)
91+
transB = istransposed(B)
92+
M, N = size(C); K = size(B, 1)
93+
pA = parent(A); pB = parent(B)
94+
ldA = stride(pA, 2)
95+
ldB = stride(pB, 2)
96+
ldC = stride(C, 2)
97+
α = one($T)
98+
β = zero($T)
99+
ccall(
100+
$fm, Cvoid,
101+
(Ref{UInt8}, Ref{UInt8}, Ref{Int64}, Ref{Int64}, Ref{Int64}, Ref{$T}, Ref{$T},
102+
Ref{Int64}, Ref{$T}, Ref{Int64}, Ref{$T}, Ref{$T}, Ref{Int64}),
103+
transA, transB, M, N, K, α, pA, ldA, pB, ldB, β, C, ldC
104+
)
105+
end
106+
end
107+
end
114108
end
115109
mkl_set_num_threads(N::Integer) = ccall(MKL_SET_NUM_THREADS, Cvoid, (Int32,), N % Int32)
116110
mkl_set_num_threads(1)
117111
openblas_set_num_threads(N::Integer) = ccall(OPENBLAS_SET_NUM_THREADS, Cvoid, (Int64,), N)
118112
openblas_set_num_threads(1)
113+
119114
function dgemvmkl!(y::AbstractVector{Float64}, A::AbstractMatrix{Float64}, x::AbstractVector{Float64}, α = 1.0, β = 0.0)
120115
transA = istransposed(A)
121116
pA = parent(A)
122117
M, N = size(pA)
123-
M32 = M % Int32
124-
N32 = N % Int32
125-
ldA = stride(pA, 2) % Int32
126-
incx = LinearAlgebra.stride1(x) % Int32
127-
incy = LinearAlgebra.stride1(y) % Int32
118+
ldA = stride(pA, 2)
119+
incx = LinearAlgebra.stride1(x)
120+
incy = LinearAlgebra.stride1(y)
128121
ccall(
129122
DGEMV_MKL, Cvoid,
130-
(Ref{UInt8}, Ref{Int32}, Ref{Int32}, Ref{Float64}, Ref{Float64}, Ref{Int32}, Ref{Float64}, Ref{Int32}, Ref{Float64}, Ref{Float64}, Ref{Int32}),
131-
transA, M32, N32, α, pA, ldA, x, incx, β, y, incy
123+
(Ref{UInt8}, Ref{Int64}, Ref{Int64}, Ref{Float64}, Ref{Float64}, Ref{Int64}, Ref{Float64}, Ref{Int64}, Ref{Float64}, Ref{Float64}, Ref{Int64}),
124+
transA, M, N, α, pA, ldA, x, incx, β, y, incy
132125
)
133126
end
134127
function dgemvopenblas!(y::AbstractVector{Float64}, A::AbstractMatrix{Float64}, x::AbstractVector{Float64})

0 commit comments

Comments
 (0)