Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
249 changes: 245 additions & 4 deletions lib/mkl/interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,13 @@ function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::
sparse_gemv!(tA, _add.alpha, A, B, _add.beta, C)
end

function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::oneSparseMatrixCSC{T}, B::oneVector{T}, _add::MulAddMul) where T <: BlasReal
tA = tA in ('S', 's', 'H', 'h') ? 'T' : flip_trans(tA)
# Sparse CSC matrix-vector product: forwards to `sparse_gemv!` with the scaling
# factors stored in `_add` (alpha scales the product, beta scales `C`).
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::oneSparseMatrixCSC{T}, B::oneVector{T}, _add::MulAddMul) where {T <: BlasFloat}
    # The wrapper codes 'S'/'s'/'H'/'h' are handled as a plain ('N') product.
    if tA in ('S', 's', 'H', 'h')
        tA = 'N'
    end
    alpha, beta = _add.alpha, _add.beta
    return sparse_gemv!(tA, alpha, A, B, beta, C)
end

# Sparse COO matrix-vector product: forwards to `sparse_gemv!` with the scaling
# factors stored in `_add`.
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::oneSparseMatrixCOO{T}, B::oneVector{T}, _add::MulAddMul) where {T <: BlasFloat}
    # Any of the codes 'S'/'s'/'H'/'h' is replaced by 'N' before the call.
    is_wrapper_code = tA in ('S', 's', 'H', 'h')
    op = is_wrapper_code ? 'N' : tA
    return sparse_gemv!(op, _add.alpha, A, B, _add.beta, C)
end

Expand All @@ -18,8 +23,14 @@ function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::oneSparseM
sparse_gemm!(tA, tB, _add.alpha, A, B, _add.beta, C)
end

function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::oneSparseMatrixCSC{T}, B::oneMatrix{T}, _add::MulAddMul) where T <: BlasReal
tA = tA in ('S', 's', 'H', 'h') ? 'T' : flip_trans(tA)
# Sparse CSC times dense matrix product: forwards to `sparse_gemm!` with the
# scaling factors stored in `_add`.
function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::oneSparseMatrixCSC{T}, B::oneMatrix{T}, _add::MulAddMul) where {T <: BlasFloat}
    # The wrapper codes 'S'/'s'/'H'/'h' are handled as 'N' for both operands.
    wrapper_codes = ('S', 's', 'H', 'h')
    if tA in wrapper_codes
        tA = 'N'
    end
    if tB in wrapper_codes
        tB = 'N'
    end
    alpha, beta = _add.alpha, _add.beta
    return sparse_gemm!(tA, tB, alpha, A, B, beta, C)
end

# Sparse COO times dense matrix product: forwards to `sparse_gemm!` with the
# scaling factors stored in `_add`.
function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::oneSparseMatrixCOO{T}, B::oneMatrix{T}, _add::MulAddMul) where {T <: BlasFloat}
    # Any of the codes 'S'/'s'/'H'/'h' is replaced by 'N' for either operand.
    opA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
    opB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
    return sparse_gemm!(opA, opB, _add.alpha, A, B, _add.beta, C)
end
Expand All @@ -31,3 +42,233 @@ end
# Forwards a sparse CSR / dense-matrix `generic_trimatdiv!` call to
# `sparse_trsm!`, using a scaling factor of `one(T)` and 'N' for `B`.
function LinearAlgebra.generic_trimatdiv!(C::oneMatrix{T}, uploc, isunitc, tfun::Function, A::oneSparseMatrixCSR{T}, B::oneMatrix{T}) where T <: BlasFloat
    # Translate the wrapper function into the operation character for `A`:
    # identity -> 'N', transpose -> 'T', anything else -> 'C'.
    transa = if tfun === identity
        'N'
    elseif tfun === transpose
        'T'
    else
        'C'
    end
    return sparse_trsm!(uploc, transa, 'N', isunitc, one(T), A, B, C)
end

# Handle Transpose and Adjoint wrappers for sparse matrices
# Let the low-level wrappers handle the CSC->CSR conversion and flip_trans logic

# Matrix-vector multiplication with transpose/adjoint
# Matrix-vector product for a lazily transposed sparse CSR matrix.
#
# `tA` is the operation requested on the wrapper `A = transpose(P)`; it is
# translated into the operation to apply to the parent `P`:
#   'N' -> transpose(P)                    -> 'T' on P
#   'T' -> transpose(transpose(P)) = P     -> 'N' on P
#   'C' -> adjoint(transpose(P)) = conj(P) -> conjugated product on P
# Returns `C`, like the `Adjoint` methods in this file.
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::Transpose{T, <:oneSparseMatrixCSR{T}}, B::oneVector{T}, _add::MulAddMul) where {T <: BlasFloat}
    tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
    if tA == 'C' && T <: Complex
        # adjoint(transpose(P)) == conj(P): forwarding 'C' to the parent would
        # multiply by P' instead (wrong operator and wrong shape). Use
        #   alpha * conj(P) * B + beta * C
        #     == conj(conj(alpha) * P * conj(B) + conj(beta) * conj(C)),
        # conjugating B and C in place and restoring B afterwards.
        B .= conj.(B)
        C .= conj.(C)
        sparse_gemv!('N', conj(_add.alpha), A.parent, B, conj(_add.beta), C)
        C .= conj.(C)
        B .= conj.(B)
    else
        # For real element types conj(P) == P, so 'C' maps to 'N' like 'T' does.
        tA_final = tA == 'N' ? 'T' : 'N'
        sparse_gemv!(tA_final, _add.alpha, A.parent, B, _add.beta, C)
    end
    return C
end

# Matrix-vector product for a lazily adjointed sparse CSR matrix; the requested
# operation on the wrapper is translated into an operation on `A.parent`.
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::Adjoint{T, <:oneSparseMatrixCSR{T}}, B::oneVector{T}, _add::MulAddMul) where {T <: BlasFloat}
    op = tA in ('S', 's', 'H', 'h') ? 'N' : tA
    parent_mat = A.parent
    if op != 'T'
        # 'N' on the wrapper is 'C' on the parent; every other code is 'N'.
        sparse_gemv!(op == 'N' ? 'C' : 'N', _add.alpha, parent_mat, B, _add.beta, C)
        return C
    end
    # transpose(adjoint(P)) == conj(P): conjugate the operands and the scalars,
    # multiply by P, then undo the conjugation on C and B.
    conj_alpha = conj(_add.alpha)
    conj_beta = conj(_add.beta)
    @. B = conj(B)
    @. C = conj(C)
    sparse_gemv!('N', conj_alpha, parent_mat, B, conj_beta, C)
    @. C = conj(C)
    @. B = conj(B)
    return C
end

# Matrix-vector product for a lazily transposed sparse CSC matrix.
#
# `tA` is the operation requested on the wrapper `A = transpose(P)`; it is
# translated into the operation to apply to the parent `P`:
#   'N' -> transpose(P)                    -> 'T' on P
#   'T' -> transpose(transpose(P)) = P     -> 'N' on P
#   'C' -> adjoint(transpose(P)) = conj(P) -> conjugated product on P
# Returns `C`, like the `Adjoint` methods in this file.
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::Transpose{T, <:oneSparseMatrixCSC{T}}, B::oneVector{T}, _add::MulAddMul) where {T <: BlasFloat}
    tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
    if tA == 'C' && T <: Complex
        # adjoint(transpose(P)) == conj(P): forwarding 'C' to the parent would
        # multiply by P' instead (wrong operator and wrong shape). Use
        #   alpha * conj(P) * B + beta * C
        #     == conj(conj(alpha) * P * conj(B) + conj(beta) * conj(C)),
        # conjugating B and C in place and restoring B afterwards.
        B .= conj.(B)
        C .= conj.(C)
        sparse_gemv!('N', conj(_add.alpha), A.parent, B, conj(_add.beta), C)
        C .= conj.(C)
        B .= conj.(B)
    else
        # For real element types conj(P) == P, so 'C' maps to 'N' like 'T' does.
        tA_final = tA == 'N' ? 'T' : 'N'
        sparse_gemv!(tA_final, _add.alpha, A.parent, B, _add.beta, C)
    end
    return C
end

# Matrix-vector product for a lazily adjointed sparse CSC matrix; the requested
# operation on the wrapper is translated into an operation on `A.parent`.
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::Adjoint{T, <:oneSparseMatrixCSC{T}}, B::oneVector{T}, _add::MulAddMul) where {T <: BlasFloat}
    op = tA in ('S', 's', 'H', 'h') ? 'N' : tA
    inner = A.parent
    if op == 'T'
        # transpose(adjoint(P)) == conj(P): run the product on conjugated
        # operands and scalars, then conjugate C and B back.
        scale_prod = conj(_add.alpha)
        scale_out = conj(_add.beta)
        @. B = conj(B)
        @. C = conj(C)
        sparse_gemv!('N', scale_prod, inner, B, scale_out, C)
        @. C = conj(C)
        @. B = conj(B)
        return C
    end
    # 'N' on the wrapper is 'C' on the parent; every other code is 'N'.
    parent_op = op == 'N' ? 'C' : 'N'
    sparse_gemv!(parent_op, _add.alpha, inner, B, _add.beta, C)
    return C
end

# Matrix-vector product for a lazily transposed sparse COO matrix.
#
# `tA` is the operation requested on the wrapper `A = transpose(P)`; it is
# translated into the operation to apply to the parent `P`:
#   'N' -> transpose(P)                    -> 'T' on P
#   'T' -> transpose(transpose(P)) = P     -> 'N' on P
#   'C' -> adjoint(transpose(P)) = conj(P) -> conjugated product on P
# Returns `C`, like the `Adjoint` methods in this file.
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::Transpose{T, <:oneSparseMatrixCOO{T}}, B::oneVector{T}, _add::MulAddMul) where {T <: BlasFloat}
    tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
    if tA == 'C' && T <: Complex
        # adjoint(transpose(P)) == conj(P) when the matrix is complex:
        # forwarding 'C' to the parent would multiply by P' instead (wrong
        # operator and wrong shape). Use
        #   alpha * conj(P) * B + beta * C
        #     == conj(conj(alpha) * P * conj(B) + conj(beta) * conj(C)),
        # conjugating B and C in place and restoring B afterwards.
        B .= conj.(B)
        C .= conj.(C)
        sparse_gemv!('N', conj(_add.alpha), A.parent, B, conj(_add.beta), C)
        C .= conj.(C)
        B .= conj.(B)
    else
        # For real element types conj(P) == P, so 'C' maps to 'N' like 'T' does.
        tA_final = tA == 'N' ? 'T' : 'N'
        sparse_gemv!(tA_final, _add.alpha, A.parent, B, _add.beta, C)
    end
    return C
end

# Matrix-vector product for a lazily adjointed sparse COO matrix; the requested
# operation on the wrapper is translated into an operation on `A.parent`.
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::Adjoint{T, <:oneSparseMatrixCOO{T}}, B::oneVector{T}, _add::MulAddMul) where {T <: BlasFloat}
    requested = tA in ('S', 's', 'H', 'h') ? 'N' : tA
    coo = A.parent
    if requested != 'T'
        # 'N' on the wrapper is 'C' on the parent; every other code is 'N'.
        forwarded = requested == 'N' ? 'C' : 'N'
        sparse_gemv!(forwarded, _add.alpha, coo, B, _add.beta, C)
        return C
    end
    # transpose(adjoint(P)) == conj(P): multiply by P with conjugated operands
    # and scalars, then conjugate C and B back.
    a = conj(_add.alpha)
    b = conj(_add.beta)
    @. B = conj(B)
    @. C = conj(C)
    sparse_gemv!('N', a, coo, B, b, C)
    @. C = conj(C)
    @. B = conj(B)
    return C
end

# Handle Transpose{T, Adjoint{T, ...}} for complex matrices
# transpose(adjoint(A)) for complex matrices needs special handling
# Matrix-vector product for transpose(adjoint(P)) == conj(P) with P a complex
# sparse CSR matrix. The product runs on P itself with conjugated operands and
# scalars; C and B are conjugated back afterwards.
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::Transpose{T, <:Adjoint{T, <:oneSparseMatrixCSR{T}}}, B::oneVector{T}, _add::MulAddMul) where {T <: BlasComplex}
    op = tA in ('S', 's', 'H', 'h') ? 'N' : tA
    # 'N' and 'T' are forwarded unchanged; any other code is forwarded as 'C'.
    op = op == 'N' ? 'N' : (op == 'T' ? 'T' : 'C')
    inner = A.parent.parent
    conj_alpha = conj(_add.alpha)
    conj_beta = conj(_add.beta)
    @. B = conj(B)
    @. C = conj(C)
    sparse_gemv!(op, conj_alpha, inner, B, conj_beta, C)
    @. C = conj(C)
    @. B = conj(B)
    return C
end

# Matrix-vector product for transpose(adjoint(P)) == conj(P) with P a complex
# sparse CSC matrix. The product runs on P itself with conjugated operands and
# scalars; C and B are conjugated back afterwards.
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::Transpose{T, <:Adjoint{T, <:oneSparseMatrixCSC{T}}}, B::oneVector{T}, _add::MulAddMul) where {T <: BlasComplex}
    requested = tA in ('S', 's', 'H', 'h') ? 'N' : tA
    # 'N' and 'T' are forwarded unchanged; any other code is forwarded as 'C'.
    forwarded = if requested == 'N'
        'N'
    elseif requested == 'T'
        'T'
    else
        'C'
    end
    csc = A.parent.parent
    @. B = conj(B)
    @. C = conj(C)
    sparse_gemv!(forwarded, conj(_add.alpha), csc, B, conj(_add.beta), C)
    @. C = conj(C)
    @. B = conj(B)
    return C
end

# Matrix-vector product for transpose(adjoint(P)) == conj(P) with P a complex
# sparse COO matrix. The product runs on P itself with conjugated operands and
# scalars; C and B are conjugated back afterwards.
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::Transpose{T, <:Adjoint{T, <:oneSparseMatrixCOO{T}}}, B::oneVector{T}, _add::MulAddMul) where {T <: BlasComplex}
    code = tA in ('S', 's', 'H', 'h') ? 'N' : tA
    # 'N' and 'T' are forwarded unchanged; any other code is forwarded as 'C'.
    code = (code == 'N' || code == 'T') ? (code == 'N' ? 'N' : 'T') : 'C'
    coo = A.parent.parent
    scale_prod = conj(_add.alpha)
    scale_out = conj(_add.beta)
    @. B = conj(B)
    @. C = conj(C)
    sparse_gemv!(code, scale_prod, coo, B, scale_out, C)
    @. C = conj(C)
    @. B = conj(B)
    return C
end

# Custom * operators for Transpose{T, Adjoint{T, ...}} to ensure correct output size allocation
# `transpose(adjoint(P)) * x` for a complex sparse CSR parent: allocates an
# output vector with `size(A)[1]` entries and fills it through
# `generic_matvecmul!` with alpha = 1 and beta = 0.
function Base.:*(A::Transpose{T, <:Adjoint{T, <:oneSparseMatrixCSR{T}}}, x::oneVector{T}) where {T <: BlasComplex}
    nrows = first(size(A))
    y = similar(x, T, nrows)
    addmul = LinearAlgebra.MulAddMul(one(T), zero(T))
    LinearAlgebra.generic_matvecmul!(y, 'N', A, x, addmul)
    return y
end

# `transpose(adjoint(P)) * x` for a complex sparse CSC parent: allocates an
# output vector with `size(A)[1]` entries and fills it through
# `generic_matvecmul!` with alpha = 1 and beta = 0.
function Base.:*(A::Transpose{T, <:Adjoint{T, <:oneSparseMatrixCSC{T}}}, x::oneVector{T}) where {T <: BlasComplex}
    out_len = first(size(A))
    result = similar(x, T, out_len)
    scaling = LinearAlgebra.MulAddMul(one(T), zero(T))
    LinearAlgebra.generic_matvecmul!(result, 'N', A, x, scaling)
    return result
end

# `transpose(adjoint(P)) * x` for a complex sparse COO parent: allocates an
# output vector with `size(A)[1]` entries and fills it through
# `generic_matvecmul!` with alpha = 1 and beta = 0.
function Base.:*(A::Transpose{T, <:Adjoint{T, <:oneSparseMatrixCOO{T}}}, x::oneVector{T}) where {T <: BlasComplex}
    rows = first(size(A))
    out = similar(x, T, rows)
    LinearAlgebra.generic_matvecmul!(
        out, 'N', A, x, LinearAlgebra.MulAddMul(one(T), zero(T)),
    )
    return out
end

# Matrix-matrix multiplication with transpose/adjoint
# Sparse-dense matrix product for a lazily transposed sparse CSR matrix.
#
# `tA` is the operation requested on the wrapper `A = transpose(P)`; it is
# translated into the operation to apply to the parent `P`:
#   'N' -> transpose(P)                    -> 'T' on P
#   'T' -> transpose(transpose(P)) = P     -> 'N' on P
#   'C' -> adjoint(transpose(P)) = conj(P) -> conjugated product on P
# Returns `C`, like the `Adjoint` methods in this file.
function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::Transpose{T, <:oneSparseMatrixCSR{T}}, B::oneMatrix{T}, _add::MulAddMul) where {T <: BlasFloat}
    tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
    tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
    if tA == 'C' && T <: Complex
        # adjoint(transpose(P)) == conj(P): forwarding 'C' to the parent would
        # multiply by P' instead (wrong operator and wrong shape). Use
        #   alpha * conj(P) * op(B) + beta * C
        #     == conj(conj(alpha) * P * op(conj(B)) + conj(beta) * conj(C)),
        # which holds for op in N/T/C because conjugation commutes with each.
        B .= conj.(B)
        C .= conj.(C)
        sparse_gemm!('N', tB, conj(_add.alpha), A.parent, B, conj(_add.beta), C)
        C .= conj.(C)
        B .= conj.(B)
    else
        # For real element types conj(P) == P, so 'C' maps to 'N' like 'T' does.
        tA_final = tA == 'N' ? 'T' : 'N'
        sparse_gemm!(tA_final, tB, _add.alpha, A.parent, B, _add.beta, C)
    end
    return C
end

# Sparse-dense matrix product for a lazily adjointed sparse CSR matrix; the
# requested operation on the wrapper is translated into one on `A.parent`.
function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::Adjoint{T, <:oneSparseMatrixCSR{T}}, B::oneMatrix{T}, _add::MulAddMul) where {T <: BlasFloat}
    opA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
    opB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
    parent_mat = A.parent
    if opA != 'T'
        # 'N' on the wrapper is 'C' on the parent; every other code is 'N'.
        sparse_gemm!(opA == 'N' ? 'C' : 'N', opB, _add.alpha, parent_mat, B, _add.beta, C)
        return C
    end
    # transpose(adjoint(P)) == conj(P): conjugate the operands and the scalars,
    # multiply by P, then undo the conjugation on C and B.
    conj_alpha = conj(_add.alpha)
    conj_beta = conj(_add.beta)
    @. B = conj(B)
    @. C = conj(C)
    sparse_gemm!('N', opB, conj_alpha, parent_mat, B, conj_beta, C)
    @. C = conj(C)
    @. B = conj(B)
    return C
end

# Sparse-dense matrix product for a lazily transposed sparse CSC matrix.
#
# `tA` is the operation requested on the wrapper `A = transpose(P)`; it is
# translated into the operation to apply to the parent `P`:
#   'N' -> transpose(P)                    -> 'T' on P
#   'T' -> transpose(transpose(P)) = P     -> 'N' on P
#   'C' -> adjoint(transpose(P)) = conj(P) -> conjugated product on P
# Returns `C`, like the `Adjoint` methods in this file.
function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::Transpose{T, <:oneSparseMatrixCSC{T}}, B::oneMatrix{T}, _add::MulAddMul) where {T <: BlasFloat}
    tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
    tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
    if tA == 'C' && T <: Complex
        # adjoint(transpose(P)) == conj(P): forwarding 'C' to the parent would
        # multiply by P' instead (wrong operator and wrong shape). Use
        #   alpha * conj(P) * op(B) + beta * C
        #     == conj(conj(alpha) * P * op(conj(B)) + conj(beta) * conj(C)),
        # which holds for op in N/T/C because conjugation commutes with each.
        B .= conj.(B)
        C .= conj.(C)
        sparse_gemm!('N', tB, conj(_add.alpha), A.parent, B, conj(_add.beta), C)
        C .= conj.(C)
        B .= conj.(B)
    else
        # For real element types conj(P) == P, so 'C' maps to 'N' like 'T' does.
        tA_final = tA == 'N' ? 'T' : 'N'
        sparse_gemm!(tA_final, tB, _add.alpha, A.parent, B, _add.beta, C)
    end
    return C
end

# Sparse-dense matrix product for a lazily adjointed sparse CSC matrix; the
# requested operation on the wrapper is translated into one on `A.parent`.
function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::Adjoint{T, <:oneSparseMatrixCSC{T}}, B::oneMatrix{T}, _add::MulAddMul) where {T <: BlasFloat}
    wrapper_codes = ('S', 's', 'H', 'h')
    opA = tA in wrapper_codes ? 'N' : tA
    opB = tB in wrapper_codes ? 'N' : tB
    inner = A.parent
    if opA == 'T'
        # transpose(adjoint(P)) == conj(P): run the product on conjugated
        # operands and scalars, then conjugate C and B back.
        scale_prod = conj(_add.alpha)
        scale_out = conj(_add.beta)
        @. B = conj(B)
        @. C = conj(C)
        sparse_gemm!('N', opB, scale_prod, inner, B, scale_out, C)
        @. C = conj(C)
        @. B = conj(B)
        return C
    end
    # 'N' on the wrapper is 'C' on the parent; every other code is 'N'.
    parent_op = opA == 'N' ? 'C' : 'N'
    sparse_gemm!(parent_op, opB, _add.alpha, inner, B, _add.beta, C)
    return C
end

# Sparse-dense matrix product for a lazily transposed sparse COO matrix.
#
# `tA` is the operation requested on the wrapper `A = transpose(P)`; it is
# translated into the operation to apply to the parent `P`:
#   'N' -> transpose(P)                    -> 'T' on P
#   'T' -> transpose(transpose(P)) = P     -> 'N' on P
#   'C' -> adjoint(transpose(P)) = conj(P) -> conjugated product on P
# Returns `C`, like the `Adjoint` methods in this file.
function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::Transpose{T, <:oneSparseMatrixCOO{T}}, B::oneMatrix{T}, _add::MulAddMul) where {T <: BlasFloat}
    tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
    tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
    if tA == 'C' && T <: Complex
        # adjoint(transpose(P)) == conj(P): forwarding 'C' to the parent would
        # multiply by P' instead (wrong operator and wrong shape). Use
        #   alpha * conj(P) * op(B) + beta * C
        #     == conj(conj(alpha) * P * op(conj(B)) + conj(beta) * conj(C)),
        # which holds for op in N/T/C because conjugation commutes with each.
        B .= conj.(B)
        C .= conj.(C)
        sparse_gemm!('N', tB, conj(_add.alpha), A.parent, B, conj(_add.beta), C)
        C .= conj.(C)
        B .= conj.(B)
    else
        # For real element types conj(P) == P, so 'C' maps to 'N' like 'T' does.
        tA_final = tA == 'N' ? 'T' : 'N'
        sparse_gemm!(tA_final, tB, _add.alpha, A.parent, B, _add.beta, C)
    end
    return C
end

# Sparse-dense matrix product for a lazily adjointed sparse COO matrix; the
# requested operation on the wrapper is translated into one on `A.parent`.
function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::Adjoint{T, <:oneSparseMatrixCOO{T}}, B::oneMatrix{T}, _add::MulAddMul) where {T <: BlasFloat}
    requestedA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
    requestedB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
    coo = A.parent
    if requestedA != 'T'
        # 'N' on the wrapper is 'C' on the parent; every other code is 'N'.
        forwarded = requestedA == 'N' ? 'C' : 'N'
        sparse_gemm!(forwarded, requestedB, _add.alpha, coo, B, _add.beta, C)
        return C
    end
    # transpose(adjoint(P)) == conj(P): multiply by P with conjugated operands
    # and scalars, then conjugate C and B back.
    a = conj(_add.alpha)
    b = conj(_add.beta)
    @. B = conj(B)
    @. C = conj(C)
    sparse_gemm!('N', requestedB, a, coo, B, b, C)
    @. C = conj(C)
    @. B = conj(B)
    return C
end
1 change: 0 additions & 1 deletion lib/mkl/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,5 @@ end
ptrs = pointer.(batch)
return oneArray(ptrs)
end

# Return 'T' when `trans` is 'N'; return 'N' for every other character.
function flip_trans(trans::Char)
    if trans == 'N'
        return 'T'
    end
    return 'N'
end
# Return 'U' when `uplo` is 'L'; return 'L' for every other character.
function flip_uplo(uplo::Char)
    if uplo == 'L'
        return 'U'
    end
    return 'L'
end
13 changes: 11 additions & 2 deletions test/onemkl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1130,9 +1130,16 @@ end
oneMKL.sparse_optimize_gemv!(transa, dA)
oneMKL.sparse_gemv!(transa, alpha, dA, dx, beta, dy)
@test alpha * opa(A) * x + beta * y ≈ collect(dy)
end
end
dy = oneVector{T}(y)
@test alpha * opa(A) * x + beta * y ≈ Array(alpha * opa(dA) * dx + beta * dy)
tx = transa == 'N' ? rand(T, 20) : rand(T, 10)
ty = transa == 'N' ? rand(T, 10) : rand(T, 20)
dtx = oneVector{T}(tx)
dty = oneVector{T}(ty)
t = @test alpha * opa(A') * tx + beta * ty ≈ Array(alpha * opa(dA') * dtx + beta * dty)
end
end
end

@testset "sparse gemm" begin
@testset "$SparseMatrix" for SparseMatrix in (oneSparseMatrixCSR, oneSparseMatrixCSC)
Expand All @@ -1153,6 +1160,8 @@ end
oneMKL.sparse_gemm!(transa, transb, alpha, dA, dB, beta, dC)

@test alpha * opa(A) * opb(B) + beta * C ≈ collect(dC)
dC = oneMatrix{T}(C)
@test alpha * opa(A) * opb(B) + beta * C ≈ Array(alpha * opa(dA) * opb(dB) + beta * dC)
oneMKL.sparse_optimize_gemm!(transa, transb, 2, dA)
end
end
Expand Down