Skip to content

Commit 14feb76

Browse files
authored
CUBLAS: Don't use BLAS1 wrappers for strided arrays, only vectors. (#2528)
1 parent 50b953c commit 14feb76

File tree

3 files changed

+56
-36
lines changed

3 files changed

+56
-36
lines changed

lib/cublas/linalg.jl

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,24 @@ LinearAlgebra.rmul!(x::StridedCuArray{<:CublasFloat}, k::Number) =
1313
LinearAlgebra.rmul!(x::DenseCuArray{<:CublasFloat}, k::Real) =
1414
invoke(rmul!, Tuple{typeof(x), Number}, x, k)
1515

16-
function LinearAlgebra.dot(x::StridedCuArray{T}, y::StridedCuArray{T}) where T<:Union{Float16, CublasReal}
16+
function LinearAlgebra.dot(x::StridedCuVector{T},
17+
y::StridedCuVector{T}) where T<:Union{Float16, CublasReal}
1718
n = length(x)
1819
n==length(y) || throw(DimensionMismatch("dot product arguments have lengths $(length(x)) and $(length(y))"))
1920
dot(n, x, y)
2021
end
2122

22-
function LinearAlgebra.dot(x::StridedCuArray{T}, y::StridedCuArray{T}) where T<:Union{ComplexF16, CublasComplex}
23+
function LinearAlgebra.dot(x::StridedCuVector{T},
24+
y::StridedCuVector{T}) where T<:Union{ComplexF16, CublasComplex}
2325
n = length(x)
2426
n==length(y) || throw(DimensionMismatch("dot product arguments have lengths $(length(x)) and $(length(y))"))
2527
dotc(n, x, y)
2628
end
2729

30+
# resolve ambiguities with generic wrapper below
31+
LinearAlgebra.dot(x::CuArray{T}, y::CuArray{T}) where T<:Union{Float32, Float64} =
32+
invoke(LinearAlgebra.dot, Tuple{StridedCuArray{T}, StridedCuArray{T}}, x, y)
33+
2834
# generic fallback
2935
function LinearAlgebra.dot(x::AnyCuArray{T1}, y::AnyCuArray{T2}) where {T1,T2}
3036
n = length(x)
@@ -97,14 +103,16 @@ function LinearAlgebra.dot(x::AnyCuArray{T1}, y::AnyCuArray{T2}) where {T1,T2}
97103
end
98104
end
99105

100-
function LinearAlgebra.:(*)(transx::Transpose{<:Any,<:StridedCuVector{T}}, y::StridedCuVector{T}) where T<:Union{ComplexF16, CublasComplex}
106+
function LinearAlgebra.:(*)(transx::Transpose{<:Any,<:StridedCuVector{T}},
107+
y::StridedCuVector{T}) where T<:Union{ComplexF16, CublasComplex}
101108
x = transx.parent
102109
n = length(x)
103110
n==length(y) || throw(DimensionMismatch("dot product arguments have lengths $(length(x)) and $(length(y))"))
104111
return dotu(n, x, y)
105112
end
106113

107-
function LinearAlgebra.norm(x::DenseCuArray{<:Union{Float16, ComplexF16, CublasFloat}}, p::Real=2)
114+
function LinearAlgebra.norm(x::DenseCuArray{<:Union{Float16, ComplexF16, CublasFloat}},
115+
p::Real=2)
108116
if p == 2
109117
return nrm2(x)
110118
else

lib/cublas/wrappers.jl

Lines changed: 39 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -78,15 +78,22 @@ function juliaStorageType(T::Type{<:Complex}, ct::cublasComputeType_t)
7878
end
7979

8080
# Level 1
81+
82+
# most level 1 routines are intended for use on vectors, so only accept a single stride.
83+
# however, it is often convenient to also use these routines on arbitrary arrays,
84+
# interpreting them as vectors. this does not work with arbitrary strides, so we
85+
# define a union matching arrays with only a non-unit stride in the first dimension.
86+
const StridedCuVecOrDenseMat{T} = Union{StridedCuVector{T}, DenseCuArray{T}}
87+
8188
## copy
8289
for (fname, fname_64, elty) in ((:cublasDcopy_v2, :cublasDcopy_v2_64, :Float64),
8390
(:cublasScopy_v2, :cublasScopy_v2_64, :Float32),
8491
(:cublasZcopy_v2, :cublasZcopy_v2_64, :ComplexF64),
8592
(:cublasCcopy_v2, :cublasCcopy_v2_64, :ComplexF32))
8693
@eval begin
8794
function copy!(n::Integer,
88-
x::StridedCuArray{$elty},
89-
y::StridedCuArray{$elty},)
95+
x::StridedCuVecOrDenseMat{$elty},
96+
y::StridedCuVecOrDenseMat{$elty},)
9097
if CUBLAS.version() >= v"12.0"
9198
$fname_64(handle(), n, x, stride(x, 1), y, stride(y, 1))
9299
else
@@ -96,7 +103,8 @@ for (fname, fname_64, elty) in ((:cublasDcopy_v2, :cublasDcopy_v2_64, :Float64),
96103
end
97104
end
98105
end
99-
function copy!(n::Integer, x::StridedCuArray{T}, y::StridedCuArray{T}) where {T <: Union{Float16, ComplexF16}}
106+
function copy!(n::Integer, x::StridedCuVecOrDenseMat{T},
107+
y::StridedCuVecOrDenseMat{T}) where {T <: Union{Float16, ComplexF16}}
100108
copyto!(y, x) # bad
101109
end
102110

@@ -108,7 +116,7 @@ for (fname, fname_64, elty) in ((:cublasDscal_v2, :cublasDscal_v2_64, :Float64),
108116
@eval begin
109117
function scal!(n::Integer,
110118
alpha::Number,
111-
x::StridedCuArray{$elty})
119+
x::StridedCuVecOrDenseMat{$elty})
112120
if CUBLAS.version() >= v"12.0"
113121
$fname_64(handle(), n, alpha, x, stride(x, 1))
114122
else
@@ -118,7 +126,7 @@ for (fname, fname_64, elty) in ((:cublasDscal_v2, :cublasDscal_v2_64, :Float64),
118126
end
119127
end
120128
end
121-
function scal!(n::Integer, alpha::Number, x::StridedCuArray{Float16})
129+
function scal!(n::Integer, alpha::Number, x::StridedCuVecOrDenseMat{Float16})
122130
α = convert(Float32, alpha)
123131
cublasScalEx(handle(), n, Ref{Float32}(α), Float32, x, Float16, stride(x, 1), Float32)
124132
return x
@@ -129,7 +137,7 @@ for (fname, fname_64, elty, celty) in ((:cublasCsscal_v2, :cublasCsscal_v2_64, :
129137
@eval begin
130138
function scal!(n::Integer,
131139
alpha::$elty,
132-
x::StridedCuArray{$celty})
140+
x::StridedCuVecOrDenseMat{$celty})
133141
if CUBLAS.version() >= v"12.0"
134142
$fname_64(handle(), n, alpha, x, stride(x, 1))
135143
else
@@ -139,7 +147,7 @@ for (fname, fname_64, elty, celty) in ((:cublasCsscal_v2, :cublasCsscal_v2_64, :
139147
end
140148
end
141149
end
142-
function scal!(n::Integer, alpha::Number, x::StridedCuArray{ComplexF16})
150+
function scal!(n::Integer, alpha::Number, x::StridedCuVecOrDenseMat{ComplexF16})
143151
wide_x = widen.(x)
144152
scal!(n, alpha, wide_x)
145153
thin_x = convert(typeof(x), wide_x)
@@ -156,8 +164,8 @@ for (jname, fname, fname_64, elty) in ((:dot, :cublasDdot_v2, :cublasDdot_v2_64,
156164
(:dotu, :cublasCdotu_v2, :cublasCdotu_v2_64, :ComplexF32))
157165
@eval begin
158166
function $jname(n::Integer,
159-
x::StridedCuArray{$elty},
160-
y::StridedCuArray{$elty})
167+
x::StridedCuVecOrDenseMat{$elty},
168+
y::StridedCuVecOrDenseMat{$elty})
161169
result = Ref{$elty}()
162170
if CUBLAS.version() >= v"12.0"
163171
$fname_64(handle(), n, x, stride(x, 1), y, stride(y, 1), result)
@@ -168,15 +176,15 @@ for (jname, fname, fname_64, elty) in ((:dot, :cublasDdot_v2, :cublasDdot_v2_64,
168176
end
169177
end
170178
end
171-
function dot(n::Integer, x::StridedCuArray{Float16}, y::StridedCuArray{Float16})
179+
function dot(n::Integer, x::StridedCuVecOrDenseMat{Float16}, y::StridedCuVecOrDenseMat{Float16})
172180
result = Ref{Float16}()
173181
cublasDotEx(handle(), n, x, Float16, stride(x, 1), y, Float16, stride(y, 1), result, Float16, Float32)
174182
return result[]
175183
end
176-
function dotc(n::Integer, x::StridedCuArray{ComplexF16}, y::StridedCuArray{ComplexF16})
184+
function dotc(n::Integer, x::StridedCuVecOrDenseMat{ComplexF16}, y::StridedCuVecOrDenseMat{ComplexF16})
177185
convert(ComplexF16, dotc(n, convert(CuArray{ComplexF32}, x), convert(CuArray{ComplexF32}, y)))
178186
end
179-
function dotu(n::Integer, x::StridedCuArray{ComplexF16}, y::DenseCuArray{ComplexF16})
187+
function dotu(n::Integer, x::StridedCuVecOrDenseMat{ComplexF16}, y::StridedCuVecOrDenseMat{ComplexF16})
180188
convert(ComplexF16, dotu(n, convert(CuArray{ComplexF32}, x), convert(CuArray{ComplexF32}, y)))
181189
end
182190

@@ -187,7 +195,7 @@ for (fname, fname_64, elty, ret_type) in ((:cublasDnrm2_v2, :cublasDnrm2_v2_64,
187195
(:cublasScnrm2_v2, :cublasScnrm2_v2_64, :ComplexF32, :Float32))
188196
@eval begin
189197
function nrm2(n::Integer,
190-
X::StridedCuArray{$elty})
198+
X::StridedCuVecOrDenseMat{$elty})
191199
result = Ref{$ret_type}()
192200
if CUBLAS.version() >= v"12.0"
193201
$fname_64(handle(), n, X, stride(X, 1), result)
@@ -198,14 +206,14 @@ for (fname, fname_64, elty, ret_type) in ((:cublasDnrm2_v2, :cublasDnrm2_v2_64,
198206
end
199207
end
200208
end
201-
nrm2(x::StridedCuArray) = nrm2(length(x), x)
209+
nrm2(x::StridedCuVecOrDenseMat) = nrm2(length(x), x)
202210

203-
function nrm2(n::Integer, x::StridedCuArray{Float16})
211+
function nrm2(n::Integer, x::StridedCuVecOrDenseMat{Float16})
204212
result = Ref{Float16}()
205213
cublasNrm2Ex(handle(), n, x, Float16, stride(x, 1), result, Float16, Float32)
206214
return result[]
207215
end
208-
function nrm2(n::Integer, x::StridedCuArray{ComplexF16})
216+
function nrm2(n::Integer, x::StridedCuVecOrDenseMat{ComplexF16})
209217
wide_x = widen.(x)
210218
nrm = nrm2(n, wide_x)
211219
return convert(Float16, nrm)
@@ -218,7 +226,7 @@ for (fname, fname_64, elty, ret_type) in ((:cublasDasum_v2, :cublasDasum_v2_64,
218226
(:cublasScasum_v2, :cublasScasum_v2_64, :ComplexF32, :Float32))
219227
@eval begin
220228
function asum(n::Integer,
221-
x::StridedCuArray{$elty})
229+
x::StridedCuVecOrDenseMat{$elty})
222230
result = Ref{$ret_type}()
223231
if CUBLAS.version() >= v"12.0"
224232
$fname_64(handle(), n, x, stride(x, 1), result)
@@ -238,8 +246,8 @@ for (fname, fname_64, elty) in ((:cublasDaxpy_v2, :cublasDaxpy_v2_64, :Float64),
238246
@eval begin
239247
function axpy!(n::Integer,
240248
alpha::Number,
241-
dx::StridedCuArray{$elty},
242-
dy::StridedCuArray{$elty})
249+
dx::StridedCuVecOrDenseMat{$elty},
250+
dy::StridedCuVecOrDenseMat{$elty})
243251
if CUBLAS.version() >= v"12.0"
244252
$fname_64(handle(), n, alpha, dx, stride(dx, 1), dy, stride(dy, 1))
245253
else
@@ -250,12 +258,12 @@ for (fname, fname_64, elty) in ((:cublasDaxpy_v2, :cublasDaxpy_v2_64, :Float64),
250258
end
251259
end
252260

253-
function axpy!(n::Integer, alpha::Number, dx::StridedCuArray{Float16}, dy::StridedCuArray{Float16})
261+
function axpy!(n::Integer, alpha::Number, dx::StridedCuVecOrDenseMat{Float16}, dy::StridedCuVecOrDenseMat{Float16})
254262
α = convert(Float32, alpha)
255263
cublasAxpyEx(handle(), n, Ref{Float32}(α), Float32, dx, Float16, stride(dx, 1), dy, Float16, stride(dy, 1), Float32)
256264
return dy
257265
end
258-
function axpy!(n::Integer, alpha::Number, dx::StridedCuArray{ComplexF16}, dy::StridedCuArray{ComplexF16})
266+
function axpy!(n::Integer, alpha::Number, dx::StridedCuVecOrDenseMat{ComplexF16}, dy::StridedCuVecOrDenseMat{ComplexF16})
259267
wide_x = widen.(dx)
260268
wide_y = widen.(dy)
261269
axpy!(n, alpha, wide_x, wide_y)
@@ -273,8 +281,8 @@ for (fname, fname_64, elty, sty) in ((:cublasSrot_v2, :cublasSrot_v2_64, :Float3
273281
(:cublasZdrot_v2, :cublasZdrot_v2_64, :ComplexF64, :Real))
274282
@eval begin
275283
function rot!(n::Integer,
276-
x::StridedCuArray{$elty},
277-
y::StridedCuArray{$elty},
284+
x::StridedCuVecOrDenseMat{$elty},
285+
y::StridedCuVecOrDenseMat{$elty},
278286
c::Real,
279287
s::$sty)
280288
if CUBLAS.version() >= v"12.0"
@@ -294,8 +302,8 @@ for (fname, fname_64, elty) in ((:cublasSswap_v2, :cublasSswap_v2_64, :Float32),
294302
(:cublasZswap_v2, :cublasZswap_v2_64, :ComplexF64))
295303
@eval begin
296304
function swap!(n::Integer,
297-
x::StridedCuArray{$elty},
298-
y::StridedCuArray{$elty})
305+
x::StridedCuVecOrDenseMat{$elty},
306+
y::StridedCuVecOrDenseMat{$elty})
299307
if CUBLAS.version() >= v"12.0"
300308
$fname_64(handle(), n, x, stride(x, 1), y, stride(y, 1))
301309
else
@@ -308,9 +316,9 @@ end
308316

309317
function axpby!(n::Integer,
310318
alpha::Number,
311-
dx::StridedCuArray{T},
319+
dx::StridedCuVecOrDenseMat{T},
312320
beta::Number,
313-
dy::StridedCuArray{T}) where T <: Union{Float16, ComplexF16, CublasFloat}
321+
dy::StridedCuVecOrDenseMat{T}) where T <: Union{Float16, ComplexF16, CublasFloat}
314322
scal!(n, beta, dy)
315323
axpy!(n, alpha, dx, dy)
316324
dy
@@ -324,7 +332,7 @@ for (fname, fname_64, elty) in ((:cublasIdamax_v2, :cublasIdamax_v2_64, :Float64
324332
(:cublasIcamax_v2, :cublasIcamax_v2_64, :ComplexF32))
325333
@eval begin
326334
function iamax(n::Integer,
327-
dx::StridedCuArray{$elty})
335+
dx::StridedCuVecOrDenseMat{$elty})
328336
if CUBLAS.version() >= v"12.0"
329337
result = Ref{Int64}()
330338
$fname_64(handle(), n, dx, stride(dx, 1), result)
@@ -336,7 +344,7 @@ for (fname, fname_64, elty) in ((:cublasIdamax_v2, :cublasIdamax_v2_64, :Float64
336344
end
337345
end
338346
end
339-
iamax(dx::StridedCuArray) = iamax(length(dx), dx)
347+
iamax(dx::StridedCuVecOrDenseMat) = iamax(length(dx), dx)
340348

341349
## iamin
342350
# iamin is not in standard blas is a CUBLAS extension
@@ -346,7 +354,7 @@ for (fname, fname_64, elty) in ((:cublasIdamin_v2, :cublasIdamin_v2_64, :Float64
346354
(:cublasIcamin_v2, :cublasIcamin_v2_64, :ComplexF32))
347355
@eval begin
348356
function iamin(n::Integer,
349-
dx::StridedCuArray{$elty},)
357+
dx::StridedCuVecOrDenseMat{$elty},)
350358
if CUBLAS.version() >= v"12.0"
351359
result = Ref{Int64}()
352360
$fname_64(handle(), n, dx, stride(dx, 1), result)
@@ -358,7 +366,7 @@ for (fname, fname_64, elty) in ((:cublasIdamin_v2, :cublasIdamin_v2_64, :Float64
358366
end
359367
end
360368
end
361-
iamin(dx::StridedCuArray) = iamin(length(dx), dx)
369+
iamin(dx::StridedCuVecOrDenseMat) = iamin(length(dx), dx)
362370

363371
# Level 2
364372
## mv

test/base/linalg.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ end
1616
end
1717

1818
@test testf(dot, rand(Bool, 1024, 1024), rand(Float64, 1024, 1024))
19+
20+
# https://discourse.julialang.org/t/result-of-inner-product-of-two-cuarray-with-views-is-incorrect/121539
21+
@test testf(dot, view(rand(Float32, 100, 100), 2:99, 2:99),
22+
view(rand(Float32, 100, 100), 2:99, 2:99))
1923
end
2024

2125
@testset "kron" begin
@@ -33,4 +37,4 @@ end
3337
@test Array(kron(A, B)) kron(Array(A), Array(B))
3438
@test Array(kron(B, A)) kron(Array(B), Array(A))
3539
end
36-
end
40+
end

0 commit comments

Comments
 (0)