Skip to content

Commit d49eb5a

Browse files
authored
Merge pull request #202 from JuliaGPU/tb/mul_redesign
Implement new mul! API.
2 parents 7d7ef87 + 61f544c commit d49eb5a

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

src/blas.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,13 @@ function blasbuffer(A)
99
error("$(typeof(A)) doesn't support BLAS operations")
1010
end
1111

12-
for T in (Float32, Float64, ComplexF32, ComplexF64)
12+
for elty in (Float32, Float64, ComplexF32, ComplexF64)
13+
T = VERSION >= v"1.3.0-alpha.115" ? :(Union{($elty), Bool}) : elty
1314
@eval begin
1415
function BLAS.gemm!(
15-
transA::Char, transB::Char, alpha::$T,
16-
A::GPUVecOrMat{$T}, B::GPUVecOrMat{$T},
17-
beta::$T, C::GPUVecOrMat{$T}
16+
transA::AbstractChar, transB::AbstractChar, alpha::$T,
17+
A::GPUVecOrMat{$elty}, B::GPUVecOrMat{$elty},
18+
beta::$T, C::GPUVecOrMat{$elty}
1819
)
1920
blasmod = blas_module(A)
2021
result = blasmod.gemm!(
@@ -54,8 +55,9 @@ end
5455

5556

5657
for elty in (Float32, Float64, ComplexF32, ComplexF64)
58+
T = VERSION >= v"1.3.0-alpha.115" ? :(Union{($elty), Bool}) : elty
5759
@eval begin
58-
function BLAS.gemv!(trans::Char, alpha::($elty), A::GPUVecOrMat{$elty}, X::GPUVector{$elty}, beta::($elty), Y::GPUVector{$elty})
60+
function BLAS.gemv!(trans::AbstractChar, alpha::$T, A::GPUVecOrMat{$elty}, X::GPUVector{$elty}, beta::$T, Y::GPUVector{$elty})
5961
m, n = size(A, 1), size(A, 2)
6062
if trans == 'N' && (length(X) != n || length(Y) != m)
6163
throw(DimensionMismatch("A has dimensions $(size(A)), X has length $(length(X)) and Y has length $(length(Y))"))
@@ -93,7 +95,7 @@ end
9395

9496
for elty in (Float32, Float64, ComplexF32, ComplexF64)
9597
@eval begin
96-
function BLAS.gbmv!(trans::Char, m::Int, kl::Int, ku::Int, alpha::($elty), A::GPUMatrix{$elty}, X::GPUVector{$elty}, beta::($elty), Y::GPUVector{$elty})
98+
function BLAS.gbmv!(trans::AbstractChar, m::Integer, kl::Integer, ku::Integer, alpha::($elty), A::GPUMatrix{$elty}, X::GPUVector{$elty}, beta::($elty), Y::GPUVector{$elty})
9799
n = size(A, 2)
98100
if trans == 'N' && (length(X) != n || length(Y) != m)
99101
throw(DimensionMismatch("A has dimensions $n, $m, X has length $(length(X)) and Y has length $(length(Y))"))

0 commit comments

Comments
 (0)