|
1 | 1 | # integration with LinearAlgebra stdlib
|
2 | 2 |
|
3 |
| -## low-level BLAS calls |
4 |
| - |
5 |
| -function blas_module(A) |
6 |
| - error("$(typeof(A)) doesn't support BLAS operations") |
7 |
| -end |
8 |
| -function blasbuffer(A) |
9 |
| - error("$(typeof(A)) doesn't support BLAS operations") |
10 |
| -end |
11 |
| - |
12 |
| -for elty in (Float32, Float64, ComplexF32, ComplexF64) |
13 |
| - T = VERSION >= v"1.3.0-alpha.115" ? :(Union{($elty), Bool}) : elty |
14 |
| - @eval begin |
15 |
| - function BLAS.gemm!( |
16 |
| - transA::AbstractChar, transB::AbstractChar, alpha::$T, |
17 |
| - A::AbstractGPUVecOrMat{$elty}, B::AbstractGPUVecOrMat{$elty}, |
18 |
| - beta::$T, C::AbstractGPUVecOrMat{$elty} |
19 |
| - ) |
20 |
| - blasmod = blas_module(A) |
21 |
| - result = blasmod.gemm!( |
22 |
| - transA, transB, alpha, |
23 |
| - blasbuffer(A), blasbuffer(B), beta, blasbuffer(C) |
24 |
| - ) |
25 |
| - C |
26 |
| - end |
27 |
| - end |
28 |
| -end |
29 |
| - |
30 |
| -for elty in (Float64, Float32) |
31 |
| - @eval begin |
32 |
| - function BLAS.scal!( |
33 |
| - n::Integer, DA::$elty, |
34 |
| - DX::AbstractGPUArray{$elty, N}, incx::Integer |
35 |
| - ) where N |
36 |
| - blasmod = blas_module(DX) |
37 |
| - blasmod.scal!(n, DA, blasbuffer(DX), incx) |
38 |
| - DX |
39 |
| - end |
40 |
| - end |
41 |
| -end |
42 |
| - |
43 |
| -LinearAlgebra.rmul!(s::Number, X::AbstractGPUArray) = rmul!(X, s) |
44 |
| -function LinearAlgebra.rmul!(X::AbstractGPUArray{T}, s::Number) where T <: BLAS.BlasComplex |
45 |
| - R = typeof(real(zero(T))) |
46 |
| - N = 2*length(X) |
47 |
| - buff = unsafe_reinterpret(R, X, (N,)) |
48 |
| - BLAS.scal!(N, R(s), buff, 1) |
49 |
| - X |
50 |
| -end |
51 |
| -function LinearAlgebra.rmul!(X::AbstractGPUArray{T}, s::Number) where T <: Union{Float32, Float64} |
52 |
| - BLAS.scal!(length(X), T(s), X, 1) |
53 |
| - X |
54 |
| -end |
55 |
| - |
56 |
| -for elty in (Float32, Float64, ComplexF32, ComplexF64) |
57 |
| - T = VERSION >= v"1.3.0-alpha.115" ? :(Union{($elty), Bool}) : elty |
58 |
| - @eval begin |
59 |
| - function BLAS.gemv!(trans::AbstractChar, alpha::$T, A::AbstractGPUVecOrMat{$elty}, X::AbstractGPUVector{$elty}, beta::$T, Y::AbstractGPUVector{$elty}) |
60 |
| - m, n = size(A, 1), size(A, 2) |
61 |
| - if trans == 'N' && (length(X) != n || length(Y) != m) |
62 |
| - throw(DimensionMismatch("A has dimensions $(size(A)), X has length $(length(X)) and Y has length $(length(Y))")) |
63 |
| - elseif trans == 'C' && (length(X) != m || length(Y) != n) |
64 |
| - throw(DimensionMismatch("A' has dimensions $n, $m, X has length $(length(X)) and Y has length $(length(Y))")) |
65 |
| - elseif trans == 'T' && (length(X) != m || length(Y) != n) |
66 |
| - throw(DimensionMismatch("A.' has dimensions $n, $m, X has length $(length(X)) and Y has length $(length(Y))")) |
67 |
| - end |
68 |
| - blasmod = blas_module(A) |
69 |
| - blasmod.gemv!( |
70 |
| - trans, alpha, |
71 |
| - blasbuffer(A), blasbuffer(X), beta, blasbuffer(Y) |
72 |
| - ) |
73 |
| - Y |
74 |
| - end |
75 |
| - end |
76 |
| -end |
77 |
| - |
78 |
| -for elty in (Float32, Float64, ComplexF32, ComplexF64) |
79 |
| - @eval begin |
80 |
| - function BLAS.axpy!( |
81 |
| - alpha::Number, x::AbstractGPUArray{$elty}, y::AbstractGPUArray{$elty} |
82 |
| - ) |
83 |
| - if length(x) != length(y) |
84 |
| - throw(DimensionMismatch("x has length $(length(x)), but y has length $(length(y))")) |
85 |
| - end |
86 |
| - blasmod = blas_module(x) |
87 |
| - blasmod.axpy!($elty(alpha), blasbuffer(vec(x)), blasbuffer(vec(y))) |
88 |
| - y |
89 |
| - end |
90 |
| - end |
91 |
| -end |
92 |
| - |
93 |
| -for elty in (Float32, Float64, ComplexF32, ComplexF64) |
94 |
| - @eval begin |
95 |
| - function BLAS.gbmv!(trans::AbstractChar, m::Integer, kl::Integer, ku::Integer, alpha::($elty), A::AbstractGPUMatrix{$elty}, X::AbstractGPUVector{$elty}, beta::($elty), Y::AbstractGPUVector{$elty}) |
96 |
| - n = size(A, 2) |
97 |
| - if trans == 'N' && (length(X) != n || length(Y) != m) |
98 |
| - throw(DimensionMismatch("A has dimensions $n, $m, X has length $(length(X)) and Y has length $(length(Y))")) |
99 |
| - elseif trans == 'C' && (length(X) != m || length(Y) != n) |
100 |
| - throw(DimensionMismatch("A' has dimensions $n, $m, X has length $(length(X)) and Y has length $(length(Y))")) |
101 |
| - elseif trans == 'T' && (length(X) != m || length(Y) != n) |
102 |
| - throw(DimensionMismatch("A.' has dimensions $n, $m, X has length $(length(X)) and Y has length $(length(Y))")) |
103 |
| - end |
104 |
| - blasmod = blas_module(A) |
105 |
| - blasmod.gbmv!( |
106 |
| - trans, m, kl, ku, alpha, |
107 |
| - blasbuffer(A), blasbuffer(X), beta, blasbuffer(Y) |
108 |
| - ) |
109 |
| - Y |
110 |
| - end |
111 |
| - end |
112 |
| -end |
113 |
| - |
114 |
| - |
115 |
| -## high-level functionality |
116 |
| - |
117 | 3 | function LinearAlgebra.transpose!(At::AbstractGPUArray{T, 2}, A::AbstractGPUArray{T, 2}) where T
|
118 | 4 | gpu_call(At, A) do ctx, At, A
|
119 | 5 | idx = @cartesianidx A ctx
|
|
0 commit comments