@@ -78,15 +78,22 @@ function juliaStorageType(T::Type{<:Complex}, ct::cublasComputeType_t)
78
78
end
79
79
80
80
# Level 1
81
+
82
+ # most level 1 routines are intended for use on vectors, so only accept a single stride.
83
+ # however, it is often convenient to also use these routines on arbitrary arrays,
84
+ # interpreting them as vectors. this does not work with arbitrary strides, so we
85
+ # define a union matching strided vectors and dense arrays, i.e. arrays that can be addressed with a single stride in the first dimension.
86
+ const StridedCuVecOrDenseMat{T} = Union{StridedCuVector{T}, DenseCuArray{T}}
87
+
81
88
# # copy
82
89
for (fname, fname_64, elty) in ((:cublasDcopy_v2 , :cublasDcopy_v2_64 , :Float64 ),
83
90
(:cublasScopy_v2 , :cublasScopy_v2_64 , :Float32 ),
84
91
(:cublasZcopy_v2 , :cublasZcopy_v2_64 , :ComplexF64 ),
85
92
(:cublasCcopy_v2 , :cublasCcopy_v2_64 , :ComplexF32 ))
86
93
@eval begin
87
94
function copy! (n:: Integer ,
88
- x:: StridedCuArray {$elty} ,
89
- y:: StridedCuArray {$elty} ,)
95
+ x:: StridedCuVecOrDenseMat {$elty} ,
96
+ y:: StridedCuVecOrDenseMat {$elty} ,)
90
97
if CUBLAS. version () >= v " 12.0"
91
98
$ fname_64 (handle (), n, x, stride (x, 1 ), y, stride (y, 1 ))
92
99
else
@@ -96,7 +103,8 @@ for (fname, fname_64, elty) in ((:cublasDcopy_v2, :cublasDcopy_v2_64, :Float64),
96
103
end
97
104
end
98
105
end
99
# Half-precision fallback for `copy!`: no CUBLAS routine is called here; the work is
# forwarded to `copyto!`.
# NOTE(review): `n` is never used in this method, so all of `x` is copied into `y`
# rather than only `n` elements -- presumably what the original "bad" marker meant.
function copy!(n::Integer, x::StridedCuVecOrDenseMat{T},
               y::StridedCuVecOrDenseMat{T}) where {T<:Union{Float16,ComplexF16}}
    return copyto!(y, x)
end
102
110
@@ -108,7 +116,7 @@ for (fname, fname_64, elty) in ((:cublasDscal_v2, :cublasDscal_v2_64, :Float64),
108
116
@eval begin
109
117
function scal! (n:: Integer ,
110
118
alpha:: Number ,
111
- x:: StridedCuArray {$elty} )
119
+ x:: StridedCuVecOrDenseMat {$elty} )
112
120
if CUBLAS. version () >= v " 12.0"
113
121
$ fname_64 (handle (), n, alpha, x, stride (x, 1 ))
114
122
else
@@ -118,7 +126,7 @@ for (fname, fname_64, elty) in ((:cublasDscal_v2, :cublasDscal_v2_64, :Float64),
118
126
end
119
127
end
120
128
end
121
- function scal! (n:: Integer , alpha:: Number , x:: StridedCuArray {Float16} )
129
+ function scal! (n:: Integer , alpha:: Number , x:: StridedCuVecOrDenseMat {Float16} )
122
130
α = convert (Float32, alpha)
123
131
cublasScalEx (handle (), n, Ref {Float32} (α), Float32, x, Float16, stride (x, 1 ), Float32)
124
132
return x
@@ -129,7 +137,7 @@ for (fname, fname_64, elty, celty) in ((:cublasCsscal_v2, :cublasCsscal_v2_64, :
129
137
@eval begin
130
138
function scal! (n:: Integer ,
131
139
alpha:: $elty ,
132
- x:: StridedCuArray {$celty} )
140
+ x:: StridedCuVecOrDenseMat {$celty} )
133
141
if CUBLAS. version () >= v " 12.0"
134
142
$ fname_64 (handle (), n, alpha, x, stride (x, 1 ))
135
143
else
@@ -139,7 +147,7 @@ for (fname, fname_64, elty, celty) in ((:cublasCsscal_v2, :cublasCsscal_v2_64, :
139
147
end
140
148
end
141
149
end
142
- function scal! (n:: Integer , alpha:: Number , x:: StridedCuArray {ComplexF16} )
150
+ function scal! (n:: Integer , alpha:: Number , x:: StridedCuVecOrDenseMat {ComplexF16} )
143
151
wide_x = widen .(x)
144
152
scal! (n, alpha, wide_x)
145
153
thin_x = convert (typeof (x), wide_x)
@@ -156,8 +164,8 @@ for (jname, fname, fname_64, elty) in ((:dot, :cublasDdot_v2, :cublasDdot_v2_64,
156
164
(:dotu , :cublasCdotu_v2 , :cublasCdotu_v2_64 , :ComplexF32 ))
157
165
@eval begin
158
166
function $jname (n:: Integer ,
159
- x:: StridedCuArray {$elty} ,
160
- y:: StridedCuArray {$elty} )
167
+ x:: StridedCuVecOrDenseMat {$elty} ,
168
+ y:: StridedCuVecOrDenseMat {$elty} )
161
169
result = Ref {$elty} ()
162
170
if CUBLAS. version () >= v " 12.0"
163
171
$ fname_64 (handle (), n, x, stride (x, 1 ), y, stride (y, 1 ), result)
@@ -168,15 +176,15 @@ for (jname, fname, fname_64, elty) in ((:dot, :cublasDdot_v2, :cublasDdot_v2_64,
168
176
end
169
177
end
170
178
end
171
# Float16 dot product via `cublasDotEx`.
#
# `n` is forwarded as the element count, and each operand is passed together with its
# first-dimension stride. The result is written into a `Ref{Float16}` and returned as a
# plain value.
# NOTE(review): the trailing `Float32` is presumably the execution type -- confirm
# against the `cublasDotEx` wrapper.
function dot(n::Integer, x::StridedCuVecOrDenseMat{Float16},
             y::StridedCuVecOrDenseMat{Float16})
    acc = Ref{Float16}()
    incx = stride(x, 1)
    incy = stride(y, 1)
    cublasDotEx(handle(), n, x, Float16, incx, y, Float16, incy, acc, Float16, Float32)
    return acc[]
end
176
# ComplexF16 `dotc`: both operands are converted to `CuArray{ComplexF32}`, `dotc` is
# called on the converted arrays, and the result is converted back to `ComplexF16`.
function dotc(n::Integer, x::StridedCuVecOrDenseMat{ComplexF16},
              y::StridedCuVecOrDenseMat{ComplexF16})
    wide_x = convert(CuArray{ComplexF32}, x)
    wide_y = convert(CuArray{ComplexF32}, y)
    return convert(ComplexF16, dotc(n, wide_x, wide_y))
end
179
# ComplexF16 `dotu`: both operands are converted to `CuArray{ComplexF32}`, `dotu` is
# called on the converted arrays, and the result is converted back to `ComplexF16`.
function dotu(n::Integer, x::StridedCuVecOrDenseMat{ComplexF16},
              y::StridedCuVecOrDenseMat{ComplexF16})
    wide_x = convert(CuArray{ComplexF32}, x)
    wide_y = convert(CuArray{ComplexF32}, y)
    return convert(ComplexF16, dotu(n, wide_x, wide_y))
end
182
190
@@ -187,7 +195,7 @@ for (fname, fname_64, elty, ret_type) in ((:cublasDnrm2_v2, :cublasDnrm2_v2_64,
187
195
(:cublasScnrm2_v2 , :cublasScnrm2_v2_64 , :ComplexF32 , :Float32 ))
188
196
@eval begin
189
197
function nrm2 (n:: Integer ,
190
- X:: StridedCuArray {$elty} )
198
+ X:: StridedCuVecOrDenseMat {$elty} )
191
199
result = Ref {$ret_type} ()
192
200
if CUBLAS. version () >= v " 12.0"
193
201
$ fname_64 (handle (), n, X, stride (X, 1 ), result)
@@ -198,14 +206,14 @@ for (fname, fname_64, elty, ret_type) in ((:cublasDnrm2_v2, :cublasDnrm2_v2_64,
198
206
end
199
207
end
200
208
end
201
# Convenience method: call the `n`-argument `nrm2` with `n = length(x)`.
function nrm2(x::StridedCuVecOrDenseMat)
    return nrm2(length(x), x)
end
202
210
203
# Float16 `nrm2` via `cublasNrm2Ex`.
#
# `n` is forwarded as the element count and `x` is passed with its first-dimension
# stride. The result is written into a `Ref{Float16}` and returned as a plain value.
# NOTE(review): the trailing `Float32` is presumably the execution type -- confirm
# against the `cublasNrm2Ex` wrapper.
function nrm2(n::Integer, x::StridedCuVecOrDenseMat{Float16})
    out = Ref{Float16}()
    incx = stride(x, 1)
    cublasNrm2Ex(handle(), n, x, Float16, incx, out, Float16, Float32)
    return out[]
end
208
- function nrm2 (n:: Integer , x:: StridedCuArray {ComplexF16} )
216
+ function nrm2 (n:: Integer , x:: StridedCuVecOrDenseMat {ComplexF16} )
209
217
wide_x = widen .(x)
210
218
nrm = nrm2 (n, wide_x)
211
219
return convert (Float16, nrm)
@@ -218,7 +226,7 @@ for (fname, fname_64, elty, ret_type) in ((:cublasDasum_v2, :cublasDasum_v2_64,
218
226
(:cublasScasum_v2 , :cublasScasum_v2_64 , :ComplexF32 , :Float32 ))
219
227
@eval begin
220
228
function asum (n:: Integer ,
221
- x:: StridedCuArray {$elty} )
229
+ x:: StridedCuVecOrDenseMat {$elty} )
222
230
result = Ref {$ret_type} ()
223
231
if CUBLAS. version () >= v " 12.0"
224
232
$ fname_64 (handle (), n, x, stride (x, 1 ), result)
@@ -238,8 +246,8 @@ for (fname, fname_64, elty) in ((:cublasDaxpy_v2, :cublasDaxpy_v2_64, :Float64),
238
246
@eval begin
239
247
function axpy! (n:: Integer ,
240
248
alpha:: Number ,
241
- dx:: StridedCuArray {$elty} ,
242
- dy:: StridedCuArray {$elty} )
249
+ dx:: StridedCuVecOrDenseMat {$elty} ,
250
+ dy:: StridedCuVecOrDenseMat {$elty} )
243
251
if CUBLAS. version () >= v " 12.0"
244
252
$ fname_64 (handle (), n, alpha, dx, stride (dx, 1 ), dy, stride (dy, 1 ))
245
253
else
@@ -250,12 +258,12 @@ for (fname, fname_64, elty) in ((:cublasDaxpy_v2, :cublasDaxpy_v2_64, :Float64),
250
258
end
251
259
end
252
260
253
# Float16 `axpy!` via `cublasAxpyEx`.
#
# `alpha` is converted to `Float32` and passed by reference. `dx` and `dy` are each
# passed with their first-dimension stride. Returns `dy`.
# NOTE(review): the trailing `Float32` is presumably the execution type -- confirm
# against the `cublasAxpyEx` wrapper.
function axpy!(n::Integer, alpha::Number, dx::StridedCuVecOrDenseMat{Float16},
               dy::StridedCuVecOrDenseMat{Float16})
    alpha32 = Ref{Float32}(convert(Float32, alpha))
    cublasAxpyEx(handle(), n, alpha32, Float32, dx, Float16, stride(dx, 1),
                 dy, Float16, stride(dy, 1), Float32)
    return dy
end
258
- function axpy! (n:: Integer , alpha:: Number , dx:: StridedCuArray {ComplexF16} , dy:: StridedCuArray {ComplexF16} )
266
+ function axpy! (n:: Integer , alpha:: Number , dx:: StridedCuVecOrDenseMat {ComplexF16} , dy:: StridedCuVecOrDenseMat {ComplexF16} )
259
267
wide_x = widen .(dx)
260
268
wide_y = widen .(dy)
261
269
axpy! (n, alpha, wide_x, wide_y)
@@ -273,8 +281,8 @@ for (fname, fname_64, elty, sty) in ((:cublasSrot_v2, :cublasSrot_v2_64, :Float3
273
281
(:cublasZdrot_v2 , :cublasZdrot_v2_64 , :ComplexF64 , :Real ))
274
282
@eval begin
275
283
function rot! (n:: Integer ,
276
- x:: StridedCuArray {$elty} ,
277
- y:: StridedCuArray {$elty} ,
284
+ x:: StridedCuVecOrDenseMat {$elty} ,
285
+ y:: StridedCuVecOrDenseMat {$elty} ,
278
286
c:: Real ,
279
287
s:: $sty )
280
288
if CUBLAS. version () >= v " 12.0"
@@ -294,8 +302,8 @@ for (fname, fname_64, elty) in ((:cublasSswap_v2, :cublasSswap_v2_64, :Float32),
294
302
(:cublasZswap_v2 , :cublasZswap_v2_64 , :ComplexF64 ))
295
303
@eval begin
296
304
function swap! (n:: Integer ,
297
- x:: StridedCuArray {$elty} ,
298
- y:: StridedCuArray {$elty} )
305
+ x:: StridedCuVecOrDenseMat {$elty} ,
306
+ y:: StridedCuVecOrDenseMat {$elty} )
299
307
if CUBLAS. version () >= v " 12.0"
300
308
$ fname_64 (handle (), n, x, stride (x, 1 ), y, stride (y, 1 ))
301
309
else
308
316
309
317
function axpby! (n:: Integer ,
310
318
alpha:: Number ,
311
- dx:: StridedCuArray {T} ,
319
+ dx:: StridedCuVecOrDenseMat {T} ,
312
320
beta:: Number ,
313
- dy:: StridedCuArray {T} ) where T <: Union{Float16, ComplexF16, CublasFloat}
321
+ dy:: StridedCuVecOrDenseMat {T} ) where T <: Union{Float16, ComplexF16, CublasFloat}
314
322
scal! (n, beta, dy)
315
323
axpy! (n, alpha, dx, dy)
316
324
dy
@@ -324,7 +332,7 @@ for (fname, fname_64, elty) in ((:cublasIdamax_v2, :cublasIdamax_v2_64, :Float64
324
332
(:cublasIcamax_v2 , :cublasIcamax_v2_64 , :ComplexF32 ))
325
333
@eval begin
326
334
function iamax (n:: Integer ,
327
- dx:: StridedCuArray {$elty} )
335
+ dx:: StridedCuVecOrDenseMat {$elty} )
328
336
if CUBLAS. version () >= v " 12.0"
329
337
result = Ref {Int64} ()
330
338
$ fname_64 (handle (), n, dx, stride (dx, 1 ), result)
@@ -336,7 +344,7 @@ for (fname, fname_64, elty) in ((:cublasIdamax_v2, :cublasIdamax_v2_64, :Float64
336
344
end
337
345
end
338
346
end
339
# Convenience method: call the `n`-argument `iamax` with `n = length(dx)`.
function iamax(dx::StridedCuVecOrDenseMat)
    return iamax(length(dx), dx)
end
340
348
341
349
# # iamin
342
350
# iamin is not in standard BLAS; it is a CUBLAS extension
@@ -346,7 +354,7 @@ for (fname, fname_64, elty) in ((:cublasIdamin_v2, :cublasIdamin_v2_64, :Float64
346
354
(:cublasIcamin_v2 , :cublasIcamin_v2_64 , :ComplexF32 ))
347
355
@eval begin
348
356
function iamin (n:: Integer ,
349
- dx:: StridedCuArray {$elty} ,)
357
+ dx:: StridedCuVecOrDenseMat {$elty} ,)
350
358
if CUBLAS. version () >= v " 12.0"
351
359
result = Ref {Int64} ()
352
360
$ fname_64 (handle (), n, dx, stride (dx, 1 ), result)
@@ -358,7 +366,7 @@ for (fname, fname_64, elty) in ((:cublasIdamin_v2, :cublasIdamin_v2_64, :Float64
358
366
end
359
367
end
360
368
end
361
# Convenience method: call the `n`-argument `iamin` with `n = length(dx)`.
function iamin(dx::StridedCuVecOrDenseMat)
    return iamin(length(dx), dx)
end
362
370
363
371
# Level 2
364
372
# # mv
0 commit comments