|
61 | 61 | # end
|
62 | 62 |
|
63 | 63 |
|
"""
    randa(T, dim...)

Return random test data with element type `T` and dimensions `dim`.

For generic `T` this forwards to `rand(T, dim...)`. For signed integer types the
values are drawn from a small range starting at `-100` so that products of the
generated values stay small; the upper bound is `200`, except for `Int8`, where
it is `typemax(Int8)` because `Int8` cannot represent `200`.
"""
randa(::Type{T}, dim...) where {T} = rand(T, dim...)
function randa(::Type{T}, dim...) where {T <: Signed}
    # `Int8(200)` throws an `InexactError`, so cap the bound for that type only.
    upper = T === Int8 ? typemax(Int8) : T(200)
    return rand(T(-100):upper, dim...)
end
| 66 | + |
64 | 67 | using MKL_jll, OpenBLAS_jll
|
65 | 68 |
|
# Handles to the two BLAS shared libraries shipped by the JLL packages.
const libMKL = Libdl.dlopen(MKL_jll.libmkl_rt)
const libOpenBLAS = Libdl.dlopen(OpenBLAS_jll.libopenblas)

# Raw function pointers looked up in MKL.
const DGEMM_MKL = Libdl.dlsym(libMKL, :dgemm)
const SGEMM_MKL = Libdl.dlsym(libMKL, :sgemm)
const DGEMV_MKL = Libdl.dlsym(libMKL, :dgemv)
const MKL_SET_NUM_THREADS = Libdl.dlsym(libMKL, :MKL_Set_Num_Threads)

# Raw function pointers looked up in OpenBLAS (symbols carrying the `64_` suffix).
const DGEMM_OpenBLAS = Libdl.dlsym(libOpenBLAS, :dgemm_64_)
const SGEMM_OpenBLAS = Libdl.dlsym(libOpenBLAS, :sgemm_64_)
const DGEMV_OpenBLAS = Libdl.dlsym(libOpenBLAS, :dgemv_64_)
const OPENBLAS_SET_NUM_THREADS = Libdl.dlsym(libOpenBLAS, :openblas_set_num_threads64_)
|
75 | 80 |
|
"""
    istransposed(x)

Return the BLAS transposition flag for `x` as a `Char`: `'T'` for a `Transpose`
wrapper or an `Adjoint` with real elements, `'C'` for any other `Adjoint`, and
`'N'` for everything else.
"""
istransposed(::Any) = 'N'
istransposed(::Transpose) = 'T'
istransposed(::Adjoint) = 'C'
istransposed(::Adjoint{<:Real}) = 'T'
|
# Generate `gemmmkl!` and `gemmopenblas!` methods for `Float32` and `Float64`.
# Each method passes α = 1 and β = 0 to the matching `?gemm` routine, with the
# transposition flags taken from the wrappers around `A` and `B`.
# NOTE(review): every integer argument is passed as a 64-bit value to both back
# ends — confirm the MKL runtime is configured for 64-bit integers.
for (lib, f) in ((:GEMM_MKL, :gemmmkl!), (:GEMM_OpenBLAS, :gemmopenblas!))
    for (T, prefix) in ((Float32, :S), (Float64, :D))
        # Name of the function-pointer constant, e.g. `SGEMM_MKL`.
        gemmptr = Symbol(prefix, lib)
        @eval function $f(C::AbstractMatrix{$T}, A::AbstractMatrix{$T}, B::AbstractMatrix{$T})
            # Pass the parent arrays; any transposition is conveyed by the flags.
            parentA = parent(A)
            parentB = parent(B)
            nrows, ncols = size(C)
            inner = size(B, 1)
            return ccall(
                $gemmptr, Cvoid,
                (Ref{UInt8}, Ref{UInt8}, Ref{Int64}, Ref{Int64}, Ref{Int64}, Ref{$T}, Ref{$T},
                 Ref{Int64}, Ref{$T}, Ref{Int64}, Ref{$T}, Ref{$T}, Ref{Int64}),
                istransposed(A), istransposed(B), nrows, ncols, inner,
                one($T), parentA, stride(parentA, 2), parentB, stride(parentB, 2),
                zero($T), C, stride(C, 2)
            )
        end
    end
end
|
"""
    mkl_set_num_threads(N::Integer)

Call MKL's `MKL_Set_Num_Threads` with `N` wrapped to an `Int32`.
"""
function mkl_set_num_threads(N::Integer)
    nthreads = N % Int32
    return ccall(MKL_SET_NUM_THREADS, Cvoid, (Int32,), nthreads)
end
# Run MKL single-threaded from load time onwards.
mkl_set_num_threads(1)
|
"""
    openblas_set_num_threads(N::Integer)

Call OpenBLAS's `openblas_set_num_threads64_` with `N` passed as an `Int64`.
"""
function openblas_set_num_threads(N::Integer)
    return ccall(OPENBLAS_SET_NUM_THREADS, Cvoid, (Int64,), N)
end
# Run OpenBLAS single-threaded from load time onwards.
openblas_set_num_threads(1)
|
| 113 | + |
"""
    dgemvmkl!(y, A, x, α = 1.0, β = 0.0)

Call MKL's `dgemv` on the parent array of `A`, using the transposition flag
derived from `A`'s wrapper, and write the result into `y`. In standard BLAS
terms this is `y ← α·op(A)·x + β·y`.

NOTE(review): all integer arguments are passed as 64-bit values — confirm the
MKL runtime is configured for 64-bit integers.
"""
function dgemvmkl!(y::AbstractVector{Float64}, A::AbstractMatrix{Float64}, x::AbstractVector{Float64}, α = 1.0, β = 0.0)
    # Dimensions and leading dimension refer to the parent storage, not to `A`.
    data = parent(A)
    nrows, ncols = size(data)
    return ccall(
        DGEMV_MKL, Cvoid,
        (Ref{UInt8}, Ref{Int64}, Ref{Int64}, Ref{Float64}, Ref{Float64}, Ref{Int64},
         Ref{Float64}, Ref{Int64}, Ref{Float64}, Ref{Float64}, Ref{Int64}),
        istransposed(A), nrows, ncols, α, data, stride(data, 2),
        x, LinearAlgebra.stride1(x), β, y, LinearAlgebra.stride1(y)
    )
end
|
134 | 127 | function dgemvopenblas!(y::AbstractVector{Float64}, A::AbstractMatrix{Float64}, x::AbstractVector{Float64})
|
|
0 commit comments