diff --git a/lib/mkl/interfaces.jl b/lib/mkl/interfaces.jl
index 343131d9..9b27d68a 100644
--- a/lib/mkl/interfaces.jl
+++ b/lib/mkl/interfaces.jl
@@ -7,8 +7,13 @@ function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::
     sparse_gemv!(tA, _add.alpha, A, B, _add.beta, C)
 end
 
-function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::oneSparseMatrixCSC{T}, B::oneVector{T}, _add::MulAddMul) where T <: BlasReal
-    tA = tA in ('S', 's', 'H', 'h') ? 'T' : flip_trans(tA)
+function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::oneSparseMatrixCSC{T}, B::oneVector{T}, _add::MulAddMul) where {T <: BlasFloat}
+    tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
+    return sparse_gemv!(tA, _add.alpha, A, B, _add.beta, C)
+end
+
+function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::oneSparseMatrixCOO{T}, B::oneVector{T}, _add::MulAddMul) where {T <: BlasFloat}
+    tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
     sparse_gemv!(tA, _add.alpha, A, B, _add.beta, C)
 end
 
@@ -18,8 +23,14 @@ function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::oneSparseM
     sparse_gemm!(tA, tB, _add.alpha, A, B, _add.beta, C)
 end
 
-function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::oneSparseMatrixCSC{T}, B::oneMatrix{T}, _add::MulAddMul) where T <: BlasReal
-    tA = tA in ('S', 's', 'H', 'h') ? 'T' : flip_trans(tA)
+function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::oneSparseMatrixCSC{T}, B::oneMatrix{T}, _add::MulAddMul) where {T <: BlasFloat}
+    tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
+    tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
+    return sparse_gemm!(tA, tB, _add.alpha, A, B, _add.beta, C)
+end
+
+function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::oneSparseMatrixCOO{T}, B::oneMatrix{T}, _add::MulAddMul) where {T <: BlasFloat}
+    tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
     tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
     sparse_gemm!(tA, tB, _add.alpha, A, B, _add.beta, C)
 end
@@ -31,3 +42,141 @@ end
 function LinearAlgebra.generic_trimatdiv!(C::oneMatrix{T}, uploc, isunitc, tfun::Function, A::oneSparseMatrixCSR{T}, B::oneMatrix{T}) where T <: BlasFloat
     sparse_trsm!(uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', 'N', isunitc, one(T), A, B, C)
 end
+
+# Handle Transpose and Adjoint wrappers for sparse matrices.
+# Each method folds the wrapper and the requested operation into one operation on the
+# parent; the low-level wrappers handle the CSC->CSR conversion and flip_trans logic.
+
+# Map the symmetric/hermitian flags to 'N'; every other flag passes through unchanged.
+_plain_trans(t) = t in ('S', 's', 'H', 'h') ? 'N' : t
+
+# Fold op(transpose(P)) into (flag for P, conjugate?):
+#   'N' -> transpose(P), 'T' -> P, 'C' -> adjoint(transpose(P)) == conj(P)
+_fold_transpose(t) = t == 'N' ? ('T', false) : t == 'T' ? ('N', false) : ('N', true)
+
+# Fold op(adjoint(P)) into (flag for P, conjugate?):
+#   'N' -> adjoint(P), 'T' -> transpose(adjoint(P)) == conj(P), 'C' -> P
+_fold_adjoint(t) = t == 'N' ? ('C', false) : t == 'T' ? ('N', true) : ('N', false)
+
+# y = alpha * M * x + beta * y with M = op(A), or M = conj(op(A)) when `conjugate` is set.
+# The conjugated product is evaluated through the identity
+#   conj(alpha * conj(M) * x + beta * y) == conj(alpha) * M * conj(x) + conj(beta) * conj(y)
+# `x` is conjugated into a temporary, so the caller's input is never modified.
+function _wrapped_gemv!(trans, conjugate::Bool, alpha, A, x::oneVector{T}, beta, y::oneVector{T}) where {T}
+    # conj is the identity for real element types, so they take the direct path too
+    (conjugate && T <: Complex) || return sparse_gemv!(trans, alpha, A, x, beta, y)
+    y .= conj.(y)
+    sparse_gemv!(trans, conj(alpha), A, conj.(x), conj(beta), y)
+    y .= conj.(y)
+    return y
+end
+
+# Matrix-matrix counterpart of `_wrapped_gemv!`: C = alpha * M * opB(B) + beta * C.
+function _wrapped_gemm!(transa, transb, conjugate::Bool, alpha, A, B::oneMatrix{T}, beta, C::oneMatrix{T}) where {T}
+    (conjugate && T <: Complex) || return sparse_gemm!(transa, transb, alpha, A, B, beta, C)
+    C .= conj.(C)
+    sparse_gemm!(transa, transb, conj(alpha), A, conj.(B), conj(beta), C)
+    C .= conj.(C)
+    return C
+end
+
+# Matrix-vector multiplication with transpose/adjoint
+function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::Transpose{T, <:oneSparseMatrixCSR{T}}, B::oneVector{T}, _add::MulAddMul) where {T <: BlasFloat}
+    trans, conjugate = _fold_transpose(_plain_trans(tA))
+    return _wrapped_gemv!(trans, conjugate, _add.alpha, A.parent, B, _add.beta, C)
+end
+
+function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::Adjoint{T, <:oneSparseMatrixCSR{T}}, B::oneVector{T}, _add::MulAddMul) where {T <: BlasFloat}
+    trans, conjugate = _fold_adjoint(_plain_trans(tA))
+    _wrapped_gemv!(trans, conjugate, _add.alpha, A.parent, B, _add.beta, C)
+    return C
+end
+
+function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::Transpose{T, <:oneSparseMatrixCSC{T}}, B::oneVector{T}, _add::MulAddMul) where {T <: BlasFloat}
+    trans, conjugate = _fold_transpose(_plain_trans(tA))
+    return _wrapped_gemv!(trans, conjugate, _add.alpha, A.parent, B, _add.beta, C)
+end
+
+function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::Adjoint{T, <:oneSparseMatrixCSC{T}}, B::oneVector{T}, _add::MulAddMul) where {T <: BlasFloat}
+    trans, conjugate = _fold_adjoint(_plain_trans(tA))
+    _wrapped_gemv!(trans, conjugate, _add.alpha, A.parent, B, _add.beta, C)
+    return C
+end
+
+function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::Transpose{T, <:oneSparseMatrixCOO{T}}, B::oneVector{T}, _add::MulAddMul) where {T <: BlasFloat}
+    trans, conjugate = _fold_transpose(_plain_trans(tA))
+    return _wrapped_gemv!(trans, conjugate, _add.alpha, A.parent, B, _add.beta, C)
+end
+
+function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::Adjoint{T, <:oneSparseMatrixCOO{T}}, B::oneVector{T}, _add::MulAddMul) where {T <: BlasFloat}
+    trans, conjugate = _fold_adjoint(_plain_trans(tA))
+    _wrapped_gemv!(trans, conjugate, _add.alpha, A.parent, B, _add.beta, C)
+    return C
+end
+
+# Handle Transpose{T, Adjoint{T, ...}} for complex matrices:
+# transpose(adjoint(P)) == conj(P), so the requested operation is applied to conj(P)
+function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::Transpose{T, <:Adjoint{T, <:oneSparseMatrixCSR{T}}}, B::oneVector{T}, _add::MulAddMul) where {T <: BlasComplex}
+    return _wrapped_gemv!(_plain_trans(tA), true, _add.alpha, A.parent.parent, B, _add.beta, C)
+end
+
+function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::Transpose{T, <:Adjoint{T, <:oneSparseMatrixCSC{T}}}, B::oneVector{T}, _add::MulAddMul) where {T <: BlasComplex}
+    return _wrapped_gemv!(_plain_trans(tA), true, _add.alpha, A.parent.parent, B, _add.beta, C)
+end
+
+function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::Transpose{T, <:Adjoint{T, <:oneSparseMatrixCOO{T}}}, B::oneVector{T}, _add::MulAddMul) where {T <: BlasComplex}
+    return _wrapped_gemv!(_plain_trans(tA), true, _add.alpha, A.parent.parent, B, _add.beta, C)
+end
+
+# Custom * operators for Transpose{T, Adjoint{T, ...}} to ensure correct output size allocation
+function Base.:*(A::Transpose{T, <:Adjoint{T, <:oneSparseMatrixCSR{T}}}, x::oneVector{T}) where {T <: BlasComplex}
+    y = similar(x, T, size(A, 1))
+    LinearAlgebra.generic_matvecmul!(y, 'N', A, x, LinearAlgebra.MulAddMul(one(T), zero(T)))
+    return y
+end
+
+function Base.:*(A::Transpose{T, <:Adjoint{T, <:oneSparseMatrixCSC{T}}}, x::oneVector{T}) where {T <: BlasComplex}
+    y = similar(x, T, size(A, 1))
+    LinearAlgebra.generic_matvecmul!(y, 'N', A, x, LinearAlgebra.MulAddMul(one(T), zero(T)))
+    return y
+end
+
+function Base.:*(A::Transpose{T, <:Adjoint{T, <:oneSparseMatrixCOO{T}}}, x::oneVector{T}) where {T <: BlasComplex}
+    y = similar(x, T, size(A, 1))
+    LinearAlgebra.generic_matvecmul!(y, 'N', A, x, LinearAlgebra.MulAddMul(one(T), zero(T)))
+    return y
+end
+
+# Matrix-matrix multiplication with transpose/adjoint
+function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::Transpose{T, <:oneSparseMatrixCSR{T}}, B::oneMatrix{T}, _add::MulAddMul) where {T <: BlasFloat}
+    trans, conjugate = _fold_transpose(_plain_trans(tA))
+    return _wrapped_gemm!(trans, _plain_trans(tB), conjugate, _add.alpha, A.parent, B, _add.beta, C)
+end
+
+function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::Adjoint{T, <:oneSparseMatrixCSR{T}}, B::oneMatrix{T}, _add::MulAddMul) where {T <: BlasFloat}
+    trans, conjugate = _fold_adjoint(_plain_trans(tA))
+    _wrapped_gemm!(trans, _plain_trans(tB), conjugate, _add.alpha, A.parent, B, _add.beta, C)
+    return C
+end
+
+function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::Transpose{T, <:oneSparseMatrixCSC{T}}, B::oneMatrix{T}, _add::MulAddMul) where {T <: BlasFloat}
+    trans, conjugate = _fold_transpose(_plain_trans(tA))
+    return _wrapped_gemm!(trans, _plain_trans(tB), conjugate, _add.alpha, A.parent, B, _add.beta, C)
+end
+
+function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::Adjoint{T, <:oneSparseMatrixCSC{T}}, B::oneMatrix{T}, _add::MulAddMul) where {T <: BlasFloat}
+    trans, conjugate = _fold_adjoint(_plain_trans(tA))
+    _wrapped_gemm!(trans, _plain_trans(tB), conjugate, _add.alpha, A.parent, B, _add.beta, C)
+    return C
+end
+
+function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::Transpose{T, <:oneSparseMatrixCOO{T}}, B::oneMatrix{T}, _add::MulAddMul) where {T <: BlasFloat}
+    trans, conjugate = _fold_transpose(_plain_trans(tA))
+    return _wrapped_gemm!(trans, _plain_trans(tB), conjugate, _add.alpha, A.parent, B, _add.beta, C)
+end
+
+function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::Adjoint{T, <:oneSparseMatrixCOO{T}}, B::oneMatrix{T}, _add::MulAddMul) where {T <: BlasFloat}
+    trans, conjugate = _fold_adjoint(_plain_trans(tA))
+    _wrapped_gemm!(trans, _plain_trans(tB), conjugate, _add.alpha, A.parent, B, _add.beta, C)
+    return C
+end
diff --git a/lib/mkl/utils.jl b/lib/mkl/utils.jl
index ba8ddbcb..1ac75bca 100644
--- a/lib/mkl/utils.jl
+++ b/lib/mkl/utils.jl
@@ -113,6 +113,5 @@ end
     ptrs = pointer.(batch)
     return oneArray(ptrs)
 end
-
 flip_trans(trans::Char) = trans == 'N' ? 'T' : 'N'
 flip_uplo(uplo::Char) = uplo == 'L' ? 'U' : 'L'
diff --git a/test/onemkl.jl b/test/onemkl.jl
index 8881401c..24410dac 100644
--- a/test/onemkl.jl
+++ b/test/onemkl.jl
@@ -1130,9 +1130,16 @@ end
                 oneMKL.sparse_optimize_gemv!(transa, dA)
                 oneMKL.sparse_gemv!(transa, alpha, dA, dx, beta, dy)
                 @test alpha * opa(A) * x + beta * y ≈ collect(dy)
-                end
-            end
+                dy = oneVector{T}(y)
+                @test alpha * opa(A) * x + beta * y ≈ Array(alpha * opa(dA) * dx + beta * dy)
+                tx = transa == 'N' ? rand(T, 20) : rand(T, 10)
+                ty = transa == 'N' ? rand(T, 10) : rand(T, 20)
+                dtx = oneVector{T}(tx)
+                dty = oneVector{T}(ty)
+                @test alpha * opa(A') * tx + beta * ty ≈ Array(alpha * opa(dA') * dtx + beta * dty)
+            end
         end
+    end
 
     @testset "sparse gemm" begin
         @testset "$SparseMatrix" for SparseMatrix in (oneSparseMatrixCSR, oneSparseMatrixCSC)
@@ -1153,6 +1160,8 @@ end
 
                     oneMKL.sparse_gemm!(transa, transb, alpha, dA, dB, beta, dC)
                     @test alpha * opa(A) * opb(B) + beta * C ≈ collect(dC)
+                    dC = oneMatrix{T}(C)
+                    @test alpha * opa(A) * opb(B) + beta * C ≈ Array(alpha * opa(dA) * opb(dB) + beta * dC)
                     oneMKL.sparse_optimize_gemm!(transa, transb, 2, dA)
                 end
             end