Skip to content

Commit 619d393

Browse files
authored
Merge pull request #234 from JuliaGPU/tb/blas
Remove BLAS methods/tests
2 parents 454abc6 + 4c98563 commit 619d393

File tree

6 files changed

+7
-173
lines changed

6 files changed

+7
-173
lines changed

docs/src/interface.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,4 @@ should provide implementations of the following interfaces:
6060
GPUArrays.backend
6161
GPUArrays.device
6262
GPUArrays.unsafe_reinterpret
63-
GPUArrays.blas_module
64-
GPUArrays.blasbuffer
6563
```

src/device/memory.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,14 @@ export @LocalMemory
55

66
## thread-local array
77

8-
const shmem_counter = Ref{Int}(0)
9-
108
"""
119
Creates a local static memory shared inside one block.
1210
Equivalent to `__local` of OpenCL or `__shared__ (<variable>)` of CUDA.
1311
"""
1412
macro LocalMemory(ctx, T, N)
15-
id = (shmem_counter[] += 1)
13+
id = gensym("local_memory")
1614
quote
17-
LocalMemory($(esc(ctx)), $(esc(T)), Val($(esc(N))), Val($id))
15+
LocalMemory($(esc(ctx)), $(esc(T)), Val($(esc(N))), Val($(QuoteNode(id))))
1816
end
1917
end
2018

src/host/linalg.jl

Lines changed: 0 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -1,119 +1,5 @@
11
# integration with LinearAlgebra stdlib
22

3-
## low-level BLAS calls
4-
5-
function blas_module(A)
6-
error("$(typeof(A)) doesn't support BLAS operations")
7-
end
8-
function blasbuffer(A)
9-
error("$(typeof(A)) doesn't support BLAS operations")
10-
end
11-
12-
for elty in (Float32, Float64, ComplexF32, ComplexF64)
13-
T = VERSION >= v"1.3.0-alpha.115" ? :(Union{($elty), Bool}) : elty
14-
@eval begin
15-
function BLAS.gemm!(
16-
transA::AbstractChar, transB::AbstractChar, alpha::$T,
17-
A::AbstractGPUVecOrMat{$elty}, B::AbstractGPUVecOrMat{$elty},
18-
beta::$T, C::AbstractGPUVecOrMat{$elty}
19-
)
20-
blasmod = blas_module(A)
21-
result = blasmod.gemm!(
22-
transA, transB, alpha,
23-
blasbuffer(A), blasbuffer(B), beta, blasbuffer(C)
24-
)
25-
C
26-
end
27-
end
28-
end
29-
30-
for elty in (Float64, Float32)
31-
@eval begin
32-
function BLAS.scal!(
33-
n::Integer, DA::$elty,
34-
DX::AbstractGPUArray{$elty, N}, incx::Integer
35-
) where N
36-
blasmod = blas_module(DX)
37-
blasmod.scal!(n, DA, blasbuffer(DX), incx)
38-
DX
39-
end
40-
end
41-
end
42-
43-
LinearAlgebra.rmul!(s::Number, X::AbstractGPUArray) = rmul!(X, s)
44-
function LinearAlgebra.rmul!(X::AbstractGPUArray{T}, s::Number) where T <: BLAS.BlasComplex
45-
R = typeof(real(zero(T)))
46-
N = 2*length(X)
47-
buff = unsafe_reinterpret(R, X, (N,))
48-
BLAS.scal!(N, R(s), buff, 1)
49-
X
50-
end
51-
function LinearAlgebra.rmul!(X::AbstractGPUArray{T}, s::Number) where T <: Union{Float32, Float64}
52-
BLAS.scal!(length(X), T(s), X, 1)
53-
X
54-
end
55-
56-
for elty in (Float32, Float64, ComplexF32, ComplexF64)
57-
T = VERSION >= v"1.3.0-alpha.115" ? :(Union{($elty), Bool}) : elty
58-
@eval begin
59-
function BLAS.gemv!(trans::AbstractChar, alpha::$T, A::AbstractGPUVecOrMat{$elty}, X::AbstractGPUVector{$elty}, beta::$T, Y::AbstractGPUVector{$elty})
60-
m, n = size(A, 1), size(A, 2)
61-
if trans == 'N' && (length(X) != n || length(Y) != m)
62-
throw(DimensionMismatch("A has dimensions $(size(A)), X has length $(length(X)) and Y has length $(length(Y))"))
63-
elseif trans == 'C' && (length(X) != m || length(Y) != n)
64-
throw(DimensionMismatch("A' has dimensions $n, $m, X has length $(length(X)) and Y has length $(length(Y))"))
65-
elseif trans == 'T' && (length(X) != m || length(Y) != n)
66-
throw(DimensionMismatch("A.' has dimensions $n, $m, X has length $(length(X)) and Y has length $(length(Y))"))
67-
end
68-
blasmod = blas_module(A)
69-
blasmod.gemv!(
70-
trans, alpha,
71-
blasbuffer(A), blasbuffer(X), beta, blasbuffer(Y)
72-
)
73-
Y
74-
end
75-
end
76-
end
77-
78-
for elty in (Float32, Float64, ComplexF32, ComplexF64)
79-
@eval begin
80-
function BLAS.axpy!(
81-
alpha::Number, x::AbstractGPUArray{$elty}, y::AbstractGPUArray{$elty}
82-
)
83-
if length(x) != length(y)
84-
throw(DimensionMismatch("x has length $(length(x)), but y has length $(length(y))"))
85-
end
86-
blasmod = blas_module(x)
87-
blasmod.axpy!($elty(alpha), blasbuffer(vec(x)), blasbuffer(vec(y)))
88-
y
89-
end
90-
end
91-
end
92-
93-
for elty in (Float32, Float64, ComplexF32, ComplexF64)
94-
@eval begin
95-
function BLAS.gbmv!(trans::AbstractChar, m::Integer, kl::Integer, ku::Integer, alpha::($elty), A::AbstractGPUMatrix{$elty}, X::AbstractGPUVector{$elty}, beta::($elty), Y::AbstractGPUVector{$elty})
96-
n = size(A, 2)
97-
if trans == 'N' && (length(X) != n || length(Y) != m)
98-
throw(DimensionMismatch("A has dimensions $n, $m, X has length $(length(X)) and Y has length $(length(Y))"))
99-
elseif trans == 'C' && (length(X) != m || length(Y) != n)
100-
throw(DimensionMismatch("A' has dimensions $n, $m, X has length $(length(X)) and Y has length $(length(Y))"))
101-
elseif trans == 'T' && (length(X) != m || length(Y) != n)
102-
throw(DimensionMismatch("A.' has dimensions $n, $m, X has length $(length(X)) and Y has length $(length(Y))"))
103-
end
104-
blasmod = blas_module(A)
105-
blasmod.gbmv!(
106-
trans, m, kl, ku, alpha,
107-
blasbuffer(A), blasbuffer(X), beta, blasbuffer(Y)
108-
)
109-
Y
110-
end
111-
end
112-
end
113-
114-
115-
## high-level functionality
116-
1173
function LinearAlgebra.transpose!(At::AbstractGPUArray{T, 2}, A::AbstractGPUArray{T, 2}) where T
1184
gpu_call(At, A) do ctx, At, A
1195
idx = @cartesianidx A ctx

src/reference.jl

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
# reference implementation on the CPU
22

3+
# note that most of the code in this file serves to define a functional array type,
4+
# the actual implementation of GPUArrays-interfaces is much more limited.
5+
36
module JLArrays
47

58
using GPUArrays
@@ -25,9 +28,9 @@ struct JLBackend <: AbstractGPUBackend end
2528
mutable struct JLKernelContext <: AbstractKernelContext
2629
blockdim::Int
2730
griddim::Int
28-
2931
blockidx::Int
3032
threadidx::Int
33+
3134
localmem_counter::Int
3235
localmems::Vector{Vector{Array}}
3336
end
@@ -169,7 +172,7 @@ Base.size(x::JLArray) = x.dims
169172
Base.sizeof(x::JLArray) = Base.elsize(x) * length(x)
170173

171174

172-
## interop with other arrays
175+
## interop with Julia arrays
173176

174177
JLArray{T,N}(x::AbstractArray{S,N}) where {T,N,S} =
175178
JLArray{T,N}(convert(Array{T}, x), size(x))
@@ -266,11 +269,4 @@ to_device(ctx, x::Base.RefValue{<: JLArray}) = Base.RefValue(to_device(ctx, x[])
266269
GPUArrays.unsafe_reinterpret(::Type{T}, A::JLArray, size::Tuple) where T =
267270
reshape(reinterpret(T, A.data), size)
268271

269-
# linear algebra
270-
271-
using LinearAlgebra
272-
273-
GPUArrays.blas_module(::JLArray) = LinearAlgebra.BLAS
274-
GPUArrays.blasbuffer(A::JLArray) = A.data
275-
276272
end

test/testsuite.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ include("testsuite/mapreduce.jl")
3838
include("testsuite/broadcasting.jl")
3939
include("testsuite/linalg.jl")
4040
include("testsuite/fft.jl")
41-
include("testsuite/blas.jl")
4241
include("testsuite/random.jl")
4342

4443

@@ -55,7 +54,6 @@ function test(AT::Type{<:AbstractGPUArray})
5554
TestSuite.test_broadcasting(AT)
5655
TestSuite.test_linalg(AT)
5756
TestSuite.test_fft(AT)
58-
TestSuite.test_blas(AT)
5957
TestSuite.test_random(AT)
6058
end
6159

test/testsuite/blas.jl

Lines changed: 0 additions & 42 deletions
This file was deleted.

0 commit comments

Comments
 (0)