@@ -268,7 +268,7 @@ for (bname, fname, elty) in ((:onemklSorgqr_scratchpad_size, :onemklSorgqr, :Flo
268
268
end
269
269
end
270
270
271
- # gebrd
271
+ # gebrd
272
272
for (bname, fname, elty, relty) in ((:onemklSgebrd_scratchpad_size , :onemklSgebrd , :Float32 , :Float32 ),
273
273
(:onemklDgebrd_scratchpad_size , :onemklDgebrd , :Float64 , :Float64 ),
274
274
(:onemklCgebrd_scratchpad_size , :onemklCgebrd , :ComplexF32 , :Float32 ),
@@ -280,7 +280,7 @@ for (bname, fname, elty, relty) in ((:onemklSgebrd_scratchpad_size, :onemklSgebr
280
280
281
281
k = min (m, n)
282
282
D = oneVector {$relty} (undef, k)
283
- E = oneVector {$elty } (undef, k)
283
+ E = oneVector {$relty } (undef, k- 1 )
284
284
tauq = oneVector {$elty} (undef, k)
285
285
taup = oneVector {$elty} (undef, k)
286
286
@@ -294,6 +294,114 @@ for (bname, fname, elty, relty) in ((:onemklSgebrd_scratchpad_size, :onemklSgebr
294
294
end
295
295
end
296
296
297
+ # gesvd
298
+ for (bname, fname, elty, relty) in ((:onemklSgesvd_scratchpad_size , :onemklSgesvd , :Float32 , :Float32 ),
299
+ (:onemklDgesvd_scratchpad_size , :onemklDgesvd , :Float64 , :Float64 ),
300
+ (:onemklCgesvd_scratchpad_size , :onemklCgesvd , :ComplexF32 , :Float32 ),
301
+ (:onemklZgesvd_scratchpad_size , :onemklZgesvd , :ComplexF64 , :Float64 ))
302
+ @eval begin
303
+ function gesvd! (jobu:: Char ,
304
+ jobvt:: Char ,
305
+ A:: oneStridedMatrix{$elty} )
306
+ m, n = size (A)
307
+ lda = max (1 , stride (A, 2 ))
308
+
309
+ U = if jobu === ' A'
310
+ oneMatrix {$elty} (undef, m, m)
311
+ elseif jobu == ' S' || jobu === ' O'
312
+ oneMatrix {$elty} (undef, m, min (m, n))
313
+ elseif jobu === ' N'
314
+ oneMatrix {$elty} (undef, 0 , 0 ) # Equivalence of CU_NULL?
315
+ else
316
+ error (" jobu must be one of 'A', 'S', 'O', or 'N'" )
317
+ end
318
+ ldu = U == oneMatrix {$elty} (undef, 0 , 0 ) ? 1 : max (1 , stride (U, 2 ))
319
+ S = oneVector {$relty} (undef, min (m, n))
320
+
321
+ Vt = if jobvt === ' A'
322
+ oneMatrix {$elty} (undef, n, n)
323
+ elseif jobvt === ' S' || jobvt === ' O'
324
+ oneMatrix {$elty} (undef, min (m, n), n)
325
+ elseif jobvt === ' N'
326
+ oneMatrix {$elty} (undef, 0 , 0 )
327
+ else
328
+ error (" jobvt must be one of 'A', 'S', 'O', or 'N'" )
329
+ end
330
+ ldvt = Vt == oneArray {$elty} (undef, 0 , 0 ) ? 1 : max (1 , stride (Vt, 2 ))
331
+
332
+ queue = global_queue (context (A), device (A))
333
+ scratchpad_size = $ bname (sycl_queue (queue), jobu, jobvt, m, n, lda, ldu, ldvt)
334
+ scratchpad = oneVector {$elty} (undef, scratchpad_size)
335
+ $ fname (sycl_queue (queue), jobu, jobvt, m, n, A, lda, S, U, ldu, Vt, ldvt, scratchpad, scratchpad_size)
336
+
337
+ return U, S, Vt
338
+ end
339
+ end
340
+ end
341
+
342
+ # syevd and heevd
343
+ for (jname, bname, fname, elty, relty) in ((:syevd! , :onemklSsyevd_scratchpad_size , :onemklSsyevd , :Float32 , :Float32 ),
344
+ (:syevd! , :onemklDsyevd_scratchpad_size , :onemklDsyevd , :Float64 , :Float64 ),
345
+ (:heevd! , :onemklCheevd_scratchpad_size , :onemklCheevd , :ComplexF32 , :Float32 ),
346
+ (:heevd! , :onemklZheevd_scratchpad_size , :onemklZheevd , :ComplexF64 , :Float64 ))
347
+ @eval begin
348
+ function $jname (jobz:: Char ,
349
+ uplo:: Char ,
350
+ A:: oneStridedMatrix{$elty} )
351
+ chkuplo (uplo)
352
+ n = checksquare (A)
353
+ lda = max (1 , stride (A, 2 ))
354
+ W = oneVector {$relty} (undef, n)
355
+
356
+ queue = global_queue (context (A), device (A))
357
+ scratchpad_size = $ bname (sycl_queue (queue), jobz, uplo, n, lda)
358
+ scratchpad = oneVector {$elty} (undef, scratchpad_size)
359
+ $ fname (sycl_queue (queue), jobz, uplo, n, A, lda, W, scratchpad, scratchpad_size)
360
+
361
+ if jobz == ' N'
362
+ return W
363
+ elseif jobz == ' V'
364
+ return W, A
365
+ end
366
+ end
367
+ end
368
+ end
369
+
370
+ # sygvd and hegvd
371
+ for (jname, bname, fname, elty, relty) in ((:sygvd! , :onemklSsygvd_scratchpad_size , :onemklSsygvd , :Float32 , :Float32 ),
372
+ (:sygvd! , :onemklDsygvd_scratchpad_size , :onemklDsygvd , :Float64 , :Float64 ),
373
+ (:hegvd! , :onemklChegvd_scratchpad_size , :onemklChegvd , :ComplexF32 , :Float32 ),
374
+ (:hegvd! , :onemklZhegvd_scratchpad_size , :onemklZhegvd , :ComplexF64 , :Float64 ))
375
+ @eval begin
376
+ function $jname (itype:: Int ,
377
+ jobz:: Char ,
378
+ uplo:: Char ,
379
+ A:: oneStridedMatrix{$elty} ,
380
+ B:: oneStridedMatrix{$elty} )
381
+ chkuplo (uplo)
382
+ nA, nB = checksquare (A, B)
383
+ if nB != nA
384
+ throw (DimensionMismatch (" Dimensions of A ($nA , $nA ) and B ($nB , $nB ) must match!" ))
385
+ end
386
+ n = nA
387
+ lda = max (1 , stride (A, 2 ))
388
+ ldb = max (1 , stride (B, 2 ))
389
+ W = oneVector {$relty} (undef, n)
390
+
391
+ queue = global_queue (context (A), device (A))
392
+ scratchpad_size = $ bname (sycl_queue (queue), itype, jobz, uplo, n, lda, ldb)
393
+ scratchpad = oneVector {$elty} (undef, scratchpad_size)
394
+ $ fname (sycl_queue (queue), itype, jobz, uplo, n, A, lda, B, ldb, W, scratchpad, scratchpad_size)
395
+
396
+ if jobz == ' N'
397
+ return W
398
+ elseif jobz == ' V'
399
+ return W, A, B
400
+ end
401
+ end
402
+ end
403
+ end
404
+
297
405
# getrf_batch
298
406
for (bname, fname, elty) in ((:onemklSgetrf_batch_scratchpad_size , :onemklSgetrf_batch , :Float32 ),
299
407
(:onemklDgetrf_batch_scratchpad_size , :onemklDgetrf_batch , :Float64 ),
@@ -364,5 +472,35 @@ for elty in (:Float32, :Float64, :ComplexF32, :ComplexF64)
364
472
LinearAlgebra. LAPACK. getrs! (trans:: Char , A:: oneStridedMatrix{$elty} , ipiv:: oneStridedVector{Int64} , B:: oneStridedVecOrMat{$elty} ) = oneMKL. getrs! (trans, A, ipiv, B)
365
473
LinearAlgebra. LAPACK. ormqr! (side:: Char , trans:: Char , A:: oneStridedMatrix{$elty} , tau:: oneStridedVector{$elty} , C:: oneStridedVecOrMat{$elty} ) = oneMKL. ormqr! (side, trans, A, tau, C)
366
474
LinearAlgebra. LAPACK. orgqr! (A:: oneStridedMatrix{$elty} , tau:: oneStridedVector{$elty} ) = oneMKL. orgqr! (A, tau)
475
+ LinearAlgebra. LAPACK. gebrd! (A:: oneStridedMatrix{$elty} ) = oneMKL. gebrd! (A)
476
+ LinearAlgebra. LAPACK. gesvd! (jobu:: Char , jobvt:: Char , A:: oneStridedMatrix{$elty} ) = oneMKL. gesvd! (jobu, jobvt, A)
477
+ end
478
+ end
479
+
480
+ for elty in (:Float32 , :Float64 )
481
+ @eval begin
482
+ LinearAlgebra. LAPACK. syev! (jobz:: Char , uplo:: Char , A:: oneStridedMatrix{$elty} ) = oneMKL. syevd! (jobz, uplo, A)
483
+ LinearAlgebra. LAPACK. sygvd! (itype:: Int , jobz:: Char , uplo:: Char , A:: oneStridedMatrix{$elty} , B:: oneStridedMatrix{$elty} ) = oneMKL. sygvd! (itype, jobz, uplo, A, B)
484
+ end
485
+ end
486
+
487
+ for elty in (:ComplexF32 , :ComplexF64 )
488
+ @eval begin
489
+ LinearAlgebra. LAPACK. syev! (jobz:: Char , uplo:: Char , A:: oneStridedMatrix{$elty} ) = oneMKL. heevd! (jobz, uplo, A)
490
+ LinearAlgebra. LAPACK. sygvd! (itype:: Int , jobz:: Char , uplo:: Char , A:: oneStridedMatrix{$elty} , B:: oneStridedMatrix{$elty} ) = oneMKL. hegvd! (itype, jobz, uplo, A, B)
491
+ end
492
+ end
493
+
494
+ if VERSION >= v " 1.10"
495
+ for elty in (:Float32 , :Float64 )
496
+ @eval begin
497
+ LinearAlgebra. LAPACK. syevd! (jobz:: Char , uplo:: Char , A:: oneStridedMatrix{$elty} ) = oneMKL. syevd! (jobz, uplo, A)
498
+ end
499
+ end
500
+
501
+ for elty in (:ComplexF32 , :ComplexF64 )
502
+ @eval begin
503
+ LinearAlgebra. LAPACK. syevd! (jobz:: Char , uplo:: Char , A:: oneStridedMatrix{$elty} ) = oneMKL. heevd! (jobz, uplo, A)
504
+ end
367
505
end
368
506
end
0 commit comments