Skip to content

Commit 4c98563

Browse files
committed
Remove BLAS methods.
Doing so isn't portable, not every back-end might have a BLAS, and we should be implementing LinearAlgebra instead.
1 parent 2f85926 commit 4c98563

File tree

5 files changed

+0
-167
lines changed

5 files changed

+0
-167
lines changed

docs/src/interface.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,4 @@ should provide implementations of the following interfaces:
6060
GPUArrays.backend
6161
GPUArrays.device
6262
GPUArrays.unsafe_reinterpret
63-
GPUArrays.blas_module
64-
GPUArrays.blasbuffer
6563
```

src/host/linalg.jl

Lines changed: 0 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -1,119 +1,5 @@
11
# integration with LinearAlgebra stdlib
22

3-
## low-level BLAS calls
4-
5-
function blas_module(A)
6-
error("$(typeof(A)) doesn't support BLAS operations")
7-
end
8-
function blasbuffer(A)
9-
error("$(typeof(A)) doesn't support BLAS operations")
10-
end
11-
12-
for elty in (Float32, Float64, ComplexF32, ComplexF64)
13-
T = VERSION >= v"1.3.0-alpha.115" ? :(Union{($elty), Bool}) : elty
14-
@eval begin
15-
function BLAS.gemm!(
16-
transA::AbstractChar, transB::AbstractChar, alpha::$T,
17-
A::AbstractGPUVecOrMat{$elty}, B::AbstractGPUVecOrMat{$elty},
18-
beta::$T, C::AbstractGPUVecOrMat{$elty}
19-
)
20-
blasmod = blas_module(A)
21-
result = blasmod.gemm!(
22-
transA, transB, alpha,
23-
blasbuffer(A), blasbuffer(B), beta, blasbuffer(C)
24-
)
25-
C
26-
end
27-
end
28-
end
29-
30-
for elty in (Float64, Float32)
31-
@eval begin
32-
function BLAS.scal!(
33-
n::Integer, DA::$elty,
34-
DX::AbstractGPUArray{$elty, N}, incx::Integer
35-
) where N
36-
blasmod = blas_module(DX)
37-
blasmod.scal!(n, DA, blasbuffer(DX), incx)
38-
DX
39-
end
40-
end
41-
end
42-
43-
LinearAlgebra.rmul!(s::Number, X::AbstractGPUArray) = rmul!(X, s)
44-
function LinearAlgebra.rmul!(X::AbstractGPUArray{T}, s::Number) where T <: BLAS.BlasComplex
45-
R = typeof(real(zero(T)))
46-
N = 2*length(X)
47-
buff = unsafe_reinterpret(R, X, (N,))
48-
BLAS.scal!(N, R(s), buff, 1)
49-
X
50-
end
51-
function LinearAlgebra.rmul!(X::AbstractGPUArray{T}, s::Number) where T <: Union{Float32, Float64}
52-
BLAS.scal!(length(X), T(s), X, 1)
53-
X
54-
end
55-
56-
for elty in (Float32, Float64, ComplexF32, ComplexF64)
57-
T = VERSION >= v"1.3.0-alpha.115" ? :(Union{($elty), Bool}) : elty
58-
@eval begin
59-
function BLAS.gemv!(trans::AbstractChar, alpha::$T, A::AbstractGPUVecOrMat{$elty}, X::AbstractGPUVector{$elty}, beta::$T, Y::AbstractGPUVector{$elty})
60-
m, n = size(A, 1), size(A, 2)
61-
if trans == 'N' && (length(X) != n || length(Y) != m)
62-
throw(DimensionMismatch("A has dimensions $(size(A)), X has length $(length(X)) and Y has length $(length(Y))"))
63-
elseif trans == 'C' && (length(X) != m || length(Y) != n)
64-
throw(DimensionMismatch("A' has dimensions $n, $m, X has length $(length(X)) and Y has length $(length(Y))"))
65-
elseif trans == 'T' && (length(X) != m || length(Y) != n)
66-
throw(DimensionMismatch("A.' has dimensions $n, $m, X has length $(length(X)) and Y has length $(length(Y))"))
67-
end
68-
blasmod = blas_module(A)
69-
blasmod.gemv!(
70-
trans, alpha,
71-
blasbuffer(A), blasbuffer(X), beta, blasbuffer(Y)
72-
)
73-
Y
74-
end
75-
end
76-
end
77-
78-
for elty in (Float32, Float64, ComplexF32, ComplexF64)
79-
@eval begin
80-
function BLAS.axpy!(
81-
alpha::Number, x::AbstractGPUArray{$elty}, y::AbstractGPUArray{$elty}
82-
)
83-
if length(x) != length(y)
84-
throw(DimensionMismatch("x has length $(length(x)), but y has length $(length(y))"))
85-
end
86-
blasmod = blas_module(x)
87-
blasmod.axpy!($elty(alpha), blasbuffer(vec(x)), blasbuffer(vec(y)))
88-
y
89-
end
90-
end
91-
end
92-
93-
for elty in (Float32, Float64, ComplexF32, ComplexF64)
94-
@eval begin
95-
function BLAS.gbmv!(trans::AbstractChar, m::Integer, kl::Integer, ku::Integer, alpha::($elty), A::AbstractGPUMatrix{$elty}, X::AbstractGPUVector{$elty}, beta::($elty), Y::AbstractGPUVector{$elty})
96-
n = size(A, 2)
97-
if trans == 'N' && (length(X) != n || length(Y) != m)
98-
throw(DimensionMismatch("A has dimensions $n, $m, X has length $(length(X)) and Y has length $(length(Y))"))
99-
elseif trans == 'C' && (length(X) != m || length(Y) != n)
100-
throw(DimensionMismatch("A' has dimensions $n, $m, X has length $(length(X)) and Y has length $(length(Y))"))
101-
elseif trans == 'T' && (length(X) != m || length(Y) != n)
102-
throw(DimensionMismatch("A.' has dimensions $n, $m, X has length $(length(X)) and Y has length $(length(Y))"))
103-
end
104-
blasmod = blas_module(A)
105-
blasmod.gbmv!(
106-
trans, m, kl, ku, alpha,
107-
blasbuffer(A), blasbuffer(X), beta, blasbuffer(Y)
108-
)
109-
Y
110-
end
111-
end
112-
end
113-
114-
115-
## high-level functionality
116-
1173
function LinearAlgebra.transpose!(At::AbstractGPUArray{T, 2}, A::AbstractGPUArray{T, 2}) where T
1184
gpu_call(At, A) do ctx, At, A
1195
idx = @cartesianidx A ctx

src/reference.jl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -269,11 +269,4 @@ to_device(ctx, x::Base.RefValue{<: JLArray}) = Base.RefValue(to_device(ctx, x[])
269269
GPUArrays.unsafe_reinterpret(::Type{T}, A::JLArray, size::Tuple) where T =
270270
reshape(reinterpret(T, A.data), size)
271271

272-
# linear algebra
273-
274-
using LinearAlgebra
275-
276-
GPUArrays.blas_module(::JLArray) = LinearAlgebra.BLAS
277-
GPUArrays.blasbuffer(A::JLArray) = A.data
278-
279272
end

test/testsuite.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ include("testsuite/mapreduce.jl")
3838
include("testsuite/broadcasting.jl")
3939
include("testsuite/linalg.jl")
4040
include("testsuite/fft.jl")
41-
include("testsuite/blas.jl")
4241
include("testsuite/random.jl")
4342

4443

@@ -55,7 +54,6 @@ function test(AT::Type{<:AbstractGPUArray})
5554
TestSuite.test_broadcasting(AT)
5655
TestSuite.test_linalg(AT)
5756
TestSuite.test_fft(AT)
58-
TestSuite.test_blas(AT)
5957
TestSuite.test_random(AT)
6058
end
6159

test/testsuite/blas.jl

Lines changed: 0 additions & 42 deletions
This file was deleted.

0 commit comments

Comments
 (0)