Skip to content

Commit 1066b17

Browse files
committed
Format
1 parent e45e444 commit 1066b17

File tree

2 files changed

+69
-53
lines changed

2 files changed

+69
-53
lines changed

lib/mkl/wrappers_sparse.jl

Lines changed: 54 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -147,16 +147,20 @@ for SparseMatrix in (:oneSparseMatrixCSC,)
147147
end
148148
end
149149

150-
# Special handling for CSC matrices since they are stored as transposed CSR
151-
for (fname, elty) in ((:onemklCsparse_gemv, :ComplexF32),
152-
(:onemklZsparse_gemv, :ComplexF64))
150+
# Special handling for CSC matrices since they are stored as transposed CSR
151+
for (fname, elty) in (
152+
(:onemklCsparse_gemv, :ComplexF32),
153+
(:onemklZsparse_gemv, :ComplexF64),
154+
)
153155
@eval begin
154-
function sparse_gemv!(trans::Char,
156+
function sparse_gemv!(
157+
trans::Char,
155158
alpha::Number,
156159
A::$SparseMatrix{$elty},
157160
x::oneStridedVector{$elty},
158161
beta::Number,
159-
y::oneStridedVector{$elty})
162+
y::oneStridedVector{$elty}
163+
)
160164

161165
# Compute A^H*x via identity:
162166
# conj(y_new) = conj(alpha) * (A^T) * conj(x) + conj(beta) * conj(y)
@@ -177,32 +181,34 @@ for SparseMatrix in (:oneSparseMatrixCSC,)
177181
# Restore x
178182
x .= conj.(x)
179183
end
180-
y
184+
return y
181185
end
182186
end
183187
end
184188
@eval begin
185-
function sparse_optimize_gemv!(trans::Char, A::$SparseMatrix)
186-
# complex 'C' case is implemented using op='N' on S=A^T with conjugation trick
187-
queue = global_queue(context(A.nzVal), device(A.nzVal))
188-
onemklXsparse_optimize_gemv(sycl_queue(queue), flip_trans(trans), A.handle)
189-
return A
189+
function sparse_optimize_gemv!(trans::Char, A::$SparseMatrix)
190+
# complex 'C' case is implemented using op='N' on S=A^T with conjugation trick
191+
queue = global_queue(context(A.nzVal), device(A.nzVal))
192+
onemklXsparse_optimize_gemv(sycl_queue(queue), flip_trans(trans), A.handle)
193+
return A
190194
end
191195
end
192196
end
193197

194198
for (fname, elty) in ((:onemklSsparse_gemm, :Float32),
195-
(:onemklDsparse_gemm, :Float64),
196-
(:onemklCsparse_gemm, :ComplexF32),
197-
(:onemklZsparse_gemm, :ComplexF64))
199+
(:onemklDsparse_gemm, :Float64),
200+
(:onemklCsparse_gemm, :ComplexF32),
201+
(:onemklZsparse_gemm, :ComplexF64),
202+
)
198203
@eval begin
199204
function sparse_gemm!(transa::Char,
200-
transb::Char,
201-
alpha::Number,
202-
A::oneSparseMatrixCSR{$elty},
203-
B::oneStridedMatrix{$elty},
204-
beta::Number,
205-
C::oneStridedMatrix{$elty})
205+
transb::Char,
206+
alpha::Number,
207+
A::oneSparseMatrixCSR{$elty},
208+
B::oneStridedMatrix{$elty},
209+
beta::Number,
210+
C::oneStridedMatrix{$elty}
211+
)
206212

207213
mB, nB = size(B)
208214
mC, nC = size(C)
@@ -261,13 +267,15 @@ for (fname, elty) in (
261267
(:onemklZsparse_gemm, :ComplexF64),
262268
)
263269
@eval begin
264-
function sparse_gemm!(transa::Char,
270+
function sparse_gemm!(
271+
transa::Char,
265272
transb::Char,
266273
alpha::Number,
267274
A::oneSparseMatrixCSC{$elty},
268275
B::oneStridedMatrix{$elty},
269276
beta::Number,
270-
C::oneStridedMatrix{$elty})
277+
C::oneStridedMatrix{$elty}
278+
)
271279

272280
# Map op(A) to op(S) where S = A^T stored as CSR in the handle
273281
# transa: 'N' -> op(S)='T'; 'T' -> op(S)='N'; 'C' ->
@@ -279,8 +287,8 @@ for (fname, elty) in (
279287
(nB != nC) && (transb == 'N') && throw(ArgumentError("B and C must have the same number of columns."))
280288
(mB != nC) && (transb != 'N') && throw(ArgumentError("Bᵀ and C must have the same number of columns."))
281289
nrhs = size(B, 2)
282-
ldb = max(1,stride(B,2))
283-
ldc = max(1,stride(C,2))
290+
ldb = max(1, stride(B, 2))
291+
ldc = max(1, stride(C, 2))
284292
queue = global_queue(context(C), device())
285293

286294
# Use identity: conj(C_new) = conj(alpha) * S * conj(opB(B)) + conj(beta) * conj(C)
@@ -359,9 +367,10 @@ for (fname, elty) in ((:onemklSsparse_symv, :Float32),
359367
end
360368

361369
for (fname, elty) in ((:onemklSsparse_symv, :Float32),
362-
(:onemklDsparse_symv, :Float64),
363-
(:onemklCsparse_symv, :ComplexF32),
364-
(:onemklZsparse_symv, :ComplexF64))
370+
(:onemklDsparse_symv, :Float64),
371+
(:onemklCsparse_symv, :ComplexF32),
372+
(:onemklZsparse_symv, :ComplexF64),
373+
)
365374
@eval begin
366375
function sparse_symv!(uplo::Char,
367376
alpha::Number,
@@ -435,7 +444,7 @@ for (fname, elty) in (
435444
)
436445
queue = global_queue(context(y), device())
437446
$fname(sycl_queue(queue), uplo, flip_trans(trans), diag, alpha, A.handle, x, beta, y)
438-
y
447+
return y
439448
end
440449
end
441450
end
@@ -444,8 +453,8 @@ function sparse_optimize_trmv!(uplo::Char, trans::Char, diag::Char, A::oneSparse
444453
throw(
445454
ArgumentError(
446455
"sparse_optimize_trmv! is not supported for oneSparseMatrixCSC due to Intel oneAPI limitations. " *
447-
"Intel sparse library only supports nontrans operations for triangular matrix operations. " *
448-
"Convert to oneSparseMatrixCSR format instead."
456+
"Intel sparse library only supports nontrans operations for triangular matrix operations. " *
457+
"Convert to oneSparseMatrixCSR format instead."
449458
)
450459
)
451460
queue = global_queue(context(A.nzVal), device(A.nzVal))
@@ -499,8 +508,8 @@ for (fname, elty) in (
499508
throw(
500509
ArgumentError(
501510
"sparse_trsv! is not supported for oneSparseMatrixCSC due to Intel oneAPI limitations. " *
502-
"Intel sparse library only supports nontrans operations for triangular matrix operations. " *
503-
"Convert to oneSparseMatrixCSR format instead."
511+
"Intel sparse library only supports nontrans operations for triangular matrix operations. " *
512+
"Convert to oneSparseMatrixCSR format instead."
504513
)
505514
)
506515
queue = global_queue(context(y), device())
@@ -514,8 +523,8 @@ function sparse_optimize_trsv!(uplo::Char, trans::Char, diag::Char, A::oneSparse
514523
throw(
515524
ArgumentError(
516525
"sparse_optimize_trsv! is not supported for oneSparseMatrixCSC due to Intel oneAPI limitations. " *
517-
"Intel sparse library only supports nontrans operations for triangular matrix operations. " *
518-
"Convert to oneSparseMatrixCSR format instead."
526+
"Intel sparse library only supports nontrans operations for triangular matrix operations. " *
527+
"Convert to oneSparseMatrixCSR format instead."
519528
)
520529
)
521530
queue = global_queue(context(A.nzVal), device(A.nzVal))
@@ -571,7 +580,8 @@ for (fname, elty) in (
571580
(:onemklSsparse_trsm, :Float32),
572581
(:onemklDsparse_trsm, :Float64),
573582
(:onemklCsparse_trsm, :ComplexF32),
574-
(:onemklZsparse_trsm, :ComplexF64))
583+
(:onemklZsparse_trsm, :ComplexF64),
584+
)
575585
@eval begin
576586
function sparse_trsm!(
577587
uplo::Char,
@@ -581,7 +591,8 @@ for (fname, elty) in (
581591
alpha::Number,
582592
A::oneSparseMatrixCSC{$elty},
583593
X::oneStridedMatrix{$elty},
584-
Y::oneStridedMatrix{$elty})
594+
Y::oneStridedMatrix{$elty}
595+
)
585596

586597
# Intel oneAPI sparse trsm only supports nontrans operations for the matrix A.
587598
# Since CSC(A) is stored as CSR(A^T), we cannot map CSC operations
@@ -601,11 +612,11 @@ for (fname, elty) in (
601612
(nX != mY) && (transX != 'N') && throw(ArgumentError("Xᵀ and Y must have the same number of rows."))
602613
(mX != nY) && (transX != 'N') && throw(ArgumentError("Xᵀ and Y must have the same number of columns."))
603614
nrhs = size(X, 2)
604-
ldx = max(1,stride(X,2))
605-
ldy = max(1,stride(Y,2))
615+
ldx = max(1, stride(X, 2))
616+
ldy = max(1, stride(Y, 2))
606617
queue = global_queue(context(Y), device())
607618
$fname(sycl_queue(queue), 'C', flip_trans(transA), transX, uplo, diag, alpha, A.handle, X, nrhs, ldx, Y, ldy)
608-
Y
619+
return Y
609620
end
610621
end
611622
end
@@ -614,8 +625,8 @@ function sparse_optimize_trsm!(uplo::Char, trans::Char, diag::Char, A::oneSparse
614625
throw(
615626
ArgumentError(
616627
"sparse_optimize_trsm! is not supported for oneSparseMatrixCSC due to Intel oneAPI limitations. " *
617-
"Intel sparse library only supports nontrans operations for triangular matrix operations. " *
618-
"Convert to oneSparseMatrixCSR format instead."
628+
"Intel sparse library only supports nontrans operations for triangular matrix operations. " *
629+
"Convert to oneSparseMatrixCSR format instead."
619630
)
620631
)
621632
queue = global_queue(context(A.nzVal), device(A.nzVal))
@@ -627,8 +638,8 @@ function sparse_optimize_trsm!(uplo::Char, trans::Char, diag::Char, nrhs::Int, A
627638
throw(
628639
ArgumentError(
629640
"sparse_optimize_trsm! is not supported for oneSparseMatrixCSC due to Intel oneAPI limitations. " *
630-
"Intel sparse library only supports nontrans operations for triangular matrix operations. " *
631-
"Convert to oneSparseMatrixCSR format instead."
641+
"Intel sparse library only supports nontrans operations for triangular matrix operations. " *
642+
"Convert to oneSparseMatrixCSR format instead."
632643
)
633644
)
634645
queue = global_queue(context(A.nzVal), device(A.nzVal))

test/onemkl.jl

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1082,7 +1082,7 @@ end
10821082
end
10831083

10841084
@testset "SPARSE" begin
1085-
@testset "$T" for T in intersect(eltypes, [Float32, Float64, ComplexF32, ComplexF64])
1085+
@testset "$T" for T in intersect(eltypes, [Float32, Float64, ComplexF32, ComplexF64])
10861086
@testset "oneSparseMatrixCSR" begin
10871087
for S in (Int32, Int64)
10881088
A = sprand(T, 20, 10, 0.5)
@@ -1116,7 +1116,7 @@ end
11161116

11171117
@testset "sparse gemv" begin
11181118
@testset "$SparseMatrix" for SparseMatrix in (oneSparseMatrixCOO, oneSparseMatrixCSR, oneSparseMatrixCSC)
1119-
@testset "transa = $transa" for (transa, opa) in [('N', identity), ('T', transpose), ('C', adjoint)]
1119+
@testset "transa = $transa" for (transa, opa) in [('N', identity), ('T', transpose), ('C', adjoint)]
11201120
A = sprand(T, 20, 10, 0.5)
11211121
x = transa == 'N' ? rand(T, 10) : rand(T, 20)
11221122
y = transa == 'N' ? rand(T, 20) : rand(T, 10)
@@ -1129,7 +1129,7 @@ end
11291129
beta = rand(T)
11301130
oneMKL.sparse_optimize_gemv!(transa, dA)
11311131
oneMKL.sparse_gemv!(transa, alpha, dA, dx, beta, dy)
1132-
@test alpha * opa(A) * x + beta * y collect(dy)
1132+
@test alpha * opa(A) * x + beta * y collect(dy)
11331133
end
11341134
end
11351135
end
@@ -1174,16 +1174,18 @@ end
11741174
alpha = rand(T)
11751175
beta = rand(T)
11761176
oneMKL.sparse_symv!(uplo, alpha, dA, dx, beta, dy)
1177-
@test alpha * A * x + beta * y collect(dy)
1177+
@test alpha * A * x + beta * y collect(dy)
11781178
end
11791179
end
11801180
end
11811181

11821182
@testset "sparse trmv" begin
11831183
@testset "$SparseMatrix" for SparseMatrix in (oneSparseMatrixCSR, oneSparseMatrixCSC)
11841184
@testset "transa = $transa" for (transa, opa) in [('N', identity), ('T', transpose), ('C', adjoint)]
1185-
for (uplo, diag, wrapper) in [('L', 'N', LowerTriangular), ('L', 'U', UnitLowerTriangular),
1186-
('U', 'N', UpperTriangular), ('U', 'U', UnitUpperTriangular)]
1185+
for (uplo, diag, wrapper) in [
1186+
('L', 'N', LowerTriangular), ('L', 'U', UnitLowerTriangular),
1187+
('U', 'N', UpperTriangular), ('U', 'U', UnitUpperTriangular),
1188+
]
11871189
(transa == 'N') || continue
11881190
A = sprand(T, 10, 10, 0.5)
11891191
x = rand(T, 10)
@@ -1216,8 +1218,9 @@ end
12161218
@testset "$SparseMatrix" for SparseMatrix in (oneSparseMatrixCSR, oneSparseMatrixCSC)
12171219
@testset "transa = $transa" for (transa, opa) in [('N', identity), ('T', transpose), ('C', adjoint)]
12181220
for (uplo, diag, wrapper) in [('L', 'N', LowerTriangular), ('L', 'U', UnitLowerTriangular),
1219-
('U', 'N', UpperTriangular), ('U', 'U', UnitUpperTriangular)]
1220-
(transa == 'N') || continue
1221+
('U', 'N', UpperTriangular), ('U', 'U', UnitUpperTriangular),
1222+
]
1223+
(transa == 'N') || continue
12211224
alpha = rand(T)
12221225
A = rand(T, 10, 10) + I
12231226
A = sparse(A)
@@ -1250,8 +1253,10 @@ end
12501253
@testset "transa = $transa" for (transa, opa) in [('N', identity), ('T', transpose), ('C', adjoint)]
12511254
@testset "transx = $transx" for (transx, opx) in [('N', identity), ('T', transpose), ('C', adjoint)]
12521255
(transx != 'N') && continue
1253-
for (uplo, diag, wrapper) in [('L', 'N', LowerTriangular), ('L', 'U', UnitLowerTriangular),
1254-
('U', 'N', UpperTriangular), ('U', 'U', UnitUpperTriangular)]
1256+
for (uplo, diag, wrapper) in [
1257+
('L', 'N', LowerTriangular), ('L', 'U', UnitLowerTriangular),
1258+
('U', 'N', UpperTriangular), ('U', 'U', UnitUpperTriangular),
1259+
]
12551260
(transa == 'N') || continue
12561261
alpha = rand(T)
12571262
A = rand(T, 10, 10) + I

0 commit comments

Comments
 (0)