Skip to content

Commit 78c0f41

Browse files
alystAlexey StukalovKristofferC
authored
sparse * sparse fixes (#55)
* sp*sp: overload for BlasFloat eltype only * tests: fix special{sparse}*special{sparse} wrap tested matrices into Special() calll * SpecialMatrices Tuple * declare sp*sp for all pairs of special matrices * cosmetic fix Co-authored-by: Kristoffer Carlsson <[email protected]> --------- Co-authored-by: Alexey Stukalov <[email protected]> Co-authored-by: Kristoffer Carlsson <[email protected]>
1 parent fec49ef commit 78c0f41

File tree

2 files changed

+43
-13
lines changed

2 files changed

+43
-13
lines changed

src/interface.jl

Lines changed: 41 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@ SimpleOrSpecialOrAdjMat{T, M} = Union{M,
2424
AdjOrTranspMat{T, <:SpecialMat{T,<:M}},
2525
SpecialMat{T,<:AdjOrTranspMat{T,<:M}}}
2626

27+
const SpecialMatrices = (LowerTriangular, UpperTriangular,
28+
UnitLowerTriangular, UnitUpperTriangular,
29+
Symmetric, Hermitian)
30+
2731
# unwraps matrix A from Adjoint/Transpose transform
2832
unwrap_trans(A::AbstractMatrix) = A
2933
unwrap_trans(A::Union{Adjoint, Transpose}) = unwrap_trans(parent(A))
@@ -219,17 +223,45 @@ end
219223
# sparse * sparse overloads, have to be more specific than
220224
# the ones in SparseArrays.jl to avoid ambiguity
221225

222-
(*)(A::SparseMat{T}, B::SparseMat{T}) where T =
223-
spmatmul_sparse(A, B)
226+
for Amat in (nothing, SpecialMatrices...), Bmat in (nothing, SpecialMatrices...)
227+
Atype = !isnothing(Amat) ? :($Amat{T,S}) : :S
228+
tAtype = !isnothing(Amat) ? :($Amat{T, <:AdjOrTranspMat{T, S}}) : nothing
229+
Btype = !isnothing(Bmat) ? :($Bmat{T,S}) : :S
230+
tBtype = !isnothing(Bmat) ? :($Bmat{T, <:AdjOrTranspMat{T, S}}) : nothing
231+
232+
@eval (*)(A::$Atype, B::$Btype) where {T <: BlasFloat, S <: SparseMat{T}} =
233+
spmatmul_sparse(A, B)
234+
235+
@eval (*)(A::AdjOrTranspMat{T, $Atype}, B::$Btype) where {T <: BlasFloat, S <: SparseMat{T}} =
236+
spmatmul_sparse(A, B)
224237

225-
(*)(A::AdjOrTranspMat{T, S}, B::S) where {T <: BlasFloat, S <: SparseMat{T}} =
226-
spmatmul_sparse(A, B)
238+
@eval (*)(A::$Atype, B::AdjOrTranspMat{T, $Btype}) where {T <: BlasFloat, S <: SparseMat{T}} =
239+
spmatmul_sparse(A, B)
227240

228-
(*)(A::S, B::AdjOrTranspMat{T, S}) where {T <: BlasFloat, S <: SparseMat{T}} =
229-
spmatmul_sparse(A, B)
241+
@eval (*)(A::AdjOrTranspMat{T, $Atype}, B::AdjOrTranspMat{T, $Btype}) where {T <: BlasFloat, S <: SparseMat{T}} =
242+
spmatmul_sparse(A, B)
243+
244+
if tAtype !== nothing
245+
@eval (*)(A::$tAtype, B::$Btype) where {T <: BlasFloat, S <: SparseMat{T}} =
246+
spmatmul_sparse(A, B)
247+
248+
@eval (*)(A::$tAtype, B::AdjOrTranspMat{T, $Btype}) where {T <: BlasFloat, S <: SparseMat{T}} =
249+
spmatmul_sparse(A, B)
250+
end
230251

231-
(*)(A::AdjOrTranspMat{T, S}, B::AdjOrTranspMat{T, S}) where {T <: BlasFloat, S <: SparseMat{T}} =
232-
spmatmul_sparse(A, B)
252+
if tBtype !== nothing
253+
@eval (*)(A::$Atype, B::$tBtype) where {T <: BlasFloat, S <: SparseMat{T}} =
254+
spmatmul_sparse(A, B)
255+
256+
@eval (*)(A::AdjOrTranspMat{T, $Atype}, B::$tBtype) where {T <: BlasFloat, S <: SparseMat{T}} =
257+
spmatmul_sparse(A, B)
258+
end
259+
260+
if tAtype !== nothing && tBtype !== nothing
261+
@eval (*)(A::$tAtype, B::$tBtype) where {T <: BlasFloat, S <: SparseMat{T}} =
262+
spmatmul_sparse(A, B)
263+
end
264+
end
233265

234266
if VERSION < v"1.11" # in 1.11 these wrappers are already defined in LinearAlgebra
235267

@@ -245,9 +277,7 @@ function (\)(A::Union{S, AdjOrTranspMat{T, S}}, B::StridedMatrix{T}) where {T <:
245277
return ldiv!(C, A, B)
246278
end
247279

248-
for mat in (LowerTriangular, UpperTriangular,
249-
UnitLowerTriangular, UnitUpperTriangular,
250-
Symmetric, Hermitian)
280+
for mat in SpecialMatrices
251281

252282
@eval function (\)(A::Union{$mat{T, S}, AdjOrTranspMat{T, $mat{T, S}}, $mat{T, <:AdjOrTranspMat{T, S}}},
253283
x::StridedVector{T}

test/test_BLAS.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -728,8 +728,8 @@ end
728728
n = rand(10:50)
729729
spf = 0.1 + 0.8 * rand()
730730

731-
spA = convert_to_Aclass(sparserandn(SPMT{T, IT}, n, n, spf))
732-
spB = convert_to_Bclass(sparserandn(SPMT{T, IT}, n, n, spf))
731+
spA = Aclass(convert_to_Aclass(sparserandn(SPMT{T, IT}, n, n, spf)))
732+
spB = Bclass(convert_to_Bclass(sparserandn(SPMT{T, IT}, n, n, spf)))
733733
A = convert(Matrix, spA)
734734
B = convert(Matrix, spB)
735735

0 commit comments

Comments
 (0)