Skip to content

Commit dd857b8

Browse files
authored
Return sparse result in kron with one sparse factor (#331)
1 parent 02c2ad6 commit dd857b8

File tree

3 files changed

+35
-70
lines changed

3 files changed

+35
-70
lines changed

src/SparseArrays.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ using Base: ReshapedArray, promote_op, setindex_shape_check, to_shape, tail,
1010
using Base.Order: Forward
1111
using LinearAlgebra
1212
using LinearAlgebra: AdjOrTrans, matprod, AbstractQ, HessenbergQ, QRCompactWYQ, QRPackedQ,
13-
LQPackedQ
13+
LQPackedQ, UpperOrLowerTriangular
1414

1515

1616
import Base: +, -, *, \, /, &, |, xor, ==, zero, @propagate_inbounds

src/linalg.jl

Lines changed: 29 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1315,6 +1315,15 @@ function opnormestinv(A::AbstractSparseMatrixCSC{T}, t::Integer = min(2,maximum(
13151315
end
13161316

13171317
## kron
1318+
const _SparseArraysCSC{T} = Union{SparseVector{T}, AbstractSparseMatrixCSC{T}}
1319+
const _SparseKronArrays = Union{_SpecialArrays, _SparseArrays, AdjOrTrans{<:Any,<:_SparseArraysCSC}}
1320+
1321+
const _Symmetric_SparseKronArrays{T,A<:_SparseKronArrays} = Symmetric{T,A}
1322+
const _Hermitian_SparseKronArrays{T,A<:_SparseKronArrays} = Hermitian{T,A}
1323+
const _Triangular_SparseKronArrays{T,A<:_SparseKronArrays} = UpperOrLowerTriangular{T,A}
1324+
const _Annotated_SparseKronArrays = Union{_Triangular_SparseKronArrays, _Symmetric_SparseKronArrays, _Hermitian_SparseKronArrays}
1325+
const _SparseKronGroup = Union{_SparseKronArrays, _Annotated_SparseKronArrays}
1326+
13181327
@inline function kron!(C::SparseMatrixCSC, A::AbstractSparseMatrixCSC, B::AbstractSparseMatrixCSC)
13191328
mA, nA = size(A); mB, nB = size(B)
13201329
mC, nC = mA*mB, nA*nB
@@ -1353,19 +1362,10 @@ end
13531362
end
13541363
return C
13551364
end
1356-
@inline function kron!(C::SparseMatrixCSC, A::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}, B::AbstractSparseMatrixCSC)
1357-
return kron!(C, copy(A), B)
1358-
end
1359-
@inline function kron!(C::SparseMatrixCSC, A::AbstractSparseMatrixCSC, B::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC})
1360-
return kron!(C, A, copy(B))
1361-
end
1362-
@inline function kron!(C::SparseMatrixCSC, A::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}, B::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC})
1363-
return kron!(C, copy(A), copy(B))
1364-
end
13651365
@inline function kron!(z::SparseVector, x::SparseVector, y::SparseVector)
13661366
@boundscheck length(z) == length(x)*length(y) || throw(DimensionMismatch("length of " *
13671367
"target vector needs to be $(length(x)*length(y)), but has length $(length(z))"))
1368-
nnzx = nnz(x); nnzy = nnz(y);
1368+
nnzx, nnzy = nnz(x), nnz(y)
13691369
nzind = nonzeroinds(z)
13701370
nzval = nonzeros(z)
13711371

@@ -1380,69 +1380,35 @@ end
13801380
end
13811381
return z
13821382
end
1383+
# due to the sparse result type, there is no risk to override dense ⊗ dense here
1384+
@inline function kron!(C::SparseMatrixCSC, A::Union{_SparseKronGroup,_DenseConcatGroup}, B::Union{_SparseKronGroup,_DenseConcatGroup})
1385+
kron!(C, convert(SparseMatrixCSC, A), convert(SparseMatrixCSC, B))
1386+
end
1387+
kron!(C::SparseMatrixCSC, A::SparseVectorUnion, B::AdjOrTransSparseVectorUnion) = broadcast!(*, C, A, B)
13831388

1384-
# sparse matrix ⊗ sparse matrix
1385-
function kron(A::AbstractSparseMatrixCSC{T1,S1}, B::AbstractSparseMatrixCSC{T2,S2}) where {T1,S1,T2,S2}
1386-
mA, nA = size(A); mB, nB = size(B)
1389+
function kron(A::AbstractSparseMatrixCSC, B::AbstractSparseMatrixCSC)
1390+
mA, nA = size(A)
1391+
mB, nB = size(B)
13871392
mC, nC = mA*mB, nA*nB
1388-
Tv = typeof(one(T1)*one(T2))
1389-
Ti = promote_type(S1,S2)
1393+
Tv = typeof(oneunit(eltype(A))*oneunit(eltype(B)))
1394+
Ti = promote_type(indtype(A), indtype(B))
13901395
C = spzeros(Tv, Ti, mC, nC)
13911396
sizehint!(C, nnz(A)*nnz(B))
13921397
return @inbounds kron!(C, A, B)
13931398
end
1394-
kron(A::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}, B::AbstractSparseMatrixCSC) = kron(copy(A), B)
1395-
kron(A::AbstractSparseMatrixCSC, B::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}) = kron(A, copy(B))
1396-
function kron(A::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}, B::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC})
1397-
return kron(copy(A), copy(B))
1398-
end
1399-
1400-
# sparse vector ⊗ sparse vector
1401-
function kron(x::SparseVector{T1,S1}, y::SparseVector{T2,S2}) where {T1,S1,T2,S2}
1402-
nnzx = nnz(x); nnzy = nnz(y)
1399+
function kron(x::SparseVector, y::SparseVector)
1400+
nnzx, nnzy = nnz(x), nnz(y)
14031401
nnzz = nnzx*nnzy # number of nonzeros in new vector
1404-
nzind = Vector{promote_type(S1,S2)}(undef, nnzz) # the indices of nonzeros
1405-
nzval = Vector{typeof(one(T1)*one(T2))}(undef, nnzz) # the values of nonzeros
1402+
nzind = Vector{promote_type(indtype(x), indtype(y))}(undef, nnzz) # the indices of nonzeros
1403+
nzval = Vector{typeof(oneunit(eltype(x))*oneunit(eltype(y)))}(undef, nnzz) # the values of nonzeros
14061404
z = SparseVector(length(x)*length(y), nzind, nzval)
14071405
return @inbounds kron!(z, x, y)
14081406
end
1409-
1410-
# sparse matrix ⊗ sparse vector & vice versa
1411-
Base.@propagate_inbounds kron!(C::SparseMatrixCSC, A::AbstractSparseMatrixCSC, x::SparseVector) = kron!(C, A, SparseMatrixCSC(x))
1412-
Base.@propagate_inbounds kron!(C::SparseMatrixCSC, x::SparseVector, A::AbstractSparseMatrixCSC) = kron!(C, SparseMatrixCSC(x), A)
1413-
1414-
kron(A::AbstractSparseMatrixCSC, x::SparseVector) = kron(A, SparseMatrixCSC(x))
1415-
kron(A::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}, x::SparseVector) =
1416-
kron(copy(A), x)
1417-
kron(x::SparseVector, A::AbstractSparseMatrixCSC) = kron(SparseMatrixCSC(x), A)
1418-
kron(x::SparseVector, A::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}) =
1419-
kron(x, copy(A))
1420-
1421-
# sparse vec/mat ⊗ vec/mat and vice versa
1422-
Base.@propagate_inbounds kron!(C::SparseMatrixCSC, A::Union{SparseVector,AbstractSparseMatrixCSC}, B::VecOrMat) = kron!(C, A, sparse(B))
1423-
Base.@propagate_inbounds kron!(C::SparseMatrixCSC, A::VecOrMat, B::Union{SparseVector,AbstractSparseMatrixCSC}) = kron!(C, sparse(A), B)
1424-
1425-
kron(A::Union{SparseVector,AbstractSparseMatrixCSC,AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}}, B::VecOrMat) =
1426-
kron(A, sparse(B))
1427-
kron(A::VecOrMat, B::Union{SparseVector,AbstractSparseMatrixCSC,AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}}) =
1428-
kron(sparse(A), B)
1429-
1430-
# sparse vec/mat ⊗ Diagonal etc. and vice versa
1431-
const StructuredMatrix{T} = Union{Bidiagonal{T}, Diagonal{T}, SymTridiagonal{T}, Tridiagonal{T}}
1432-
Base.@propagate_inbounds kron!(C::SparseMatrixCSC, A::StructuredMatrix{T}, B::Union{SparseVector{S}, AbstractSparseMatrixCSC{S}}) where {T<:Number, S<:Number} =
1433-
kron!(C, sparse(A), B)
1434-
Base.@propagate_inbounds kron!(C::SparseMatrixCSC, A::Union{SparseVector{T}, AbstractSparseMatrixCSC{T}}, B::StructuredMatrix{S}) where {T<:Number, S<:Number} =
1435-
kron!(C, A, sparse(B))
1436-
1437-
Base.@propagate_inbounds kron!(C::SparseMatrixCSC, A::Union{SparseVector{T}, AbstractSparseMatrixCSC{T}}, B::Diagonal{S}) where {T<:Number, S<:Number} = kron!(C, A, sparse(B))
1438-
1439-
kron(A::StructuredMatrix{T}, B::Union{SparseVector{S}, AbstractSparseMatrixCSC{S}, AdjOrTrans{S,<:SparseVector}, AdjOrTrans{S,<:AbstractSparseMatrixCSC}}) where {T<:Number, S<:Number} =
1440-
kron(sparse(A), B)
1441-
kron(A::Union{SparseVector{T}, AbstractSparseMatrixCSC{T}, AdjOrTrans{S,<:SparseVector}, AdjOrTrans{S,<:AbstractSparseMatrixCSC}}, B::StructuredMatrix{S}) where {T<:Number, S<:Number} =
1442-
kron(A, sparse(B))
1443-
1444-
# sparse outer product
1445-
kron!(C::SparseMatrixCSC, A::SparseVectorUnion, B::AdjOrTransSparseVectorUnion) = broadcast!(*, C, A, B)
1407+
# extend to annotated sparse arrays, but leave out the (dense ⊗ dense)-case
1408+
kron(A::_SparseKronGroup, B::_SparseKronGroup) =
1409+
kron(convert(SparseMatrixCSC, A), convert(SparseMatrixCSC, B))
1410+
kron(A::_SparseKronGroup, B::_DenseConcatGroup) = kron(A, sparse(B))
1411+
kron(A::_DenseConcatGroup, B::_SparseKronGroup) = kron(sparse(A), B)
14461412
kron(A::SparseVectorUnion, B::AdjOrTransSparseVectorUnion) = A .* B
14471413

14481414
## det, inv, cond

test/linalg.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -694,7 +694,7 @@ end
694694
end
695695

696696
@testset "kronecker product" begin
697-
for (m,n) in ((5,10), (13,8), (14,10))
697+
for (m,n) in ((5,10), (13,8))
698698
a = sprand(m, 5, 0.4); a_d = Matrix(a)
699699
b = sprand(n, 6, 0.3); b_d = Matrix(b)
700700
v = view(a, :, 1); v_d = Vector(v)
@@ -718,12 +718,11 @@ end
718718
@test Array(kron(t(a), b)::SparseMatrixCSC) == kron(t(a_d), b_d)
719719
@test Array(kron(a, t(b))::SparseMatrixCSC) == kron(a_d, t(b_d))
720720
@test Array(kron(t(a), t(b))::SparseMatrixCSC) == kron(t(a_d), t(b_d))
721-
@test Array(kron(a_d, t(b))::SparseMatrixCSC) == kron(a_d, t(b_d))
722721
@test Array(kron(t(a), b_d)::SparseMatrixCSC) == kron(t(a_d), b_d)
723-
@test issparse(kron(c, d_di))
724-
@test Array(kron(c, d_di)) == kron(c_d, d_d)
725-
@test issparse(kron(c_di, d))
726-
@test Array(kron(c_di, d)) == kron(c_d, d_d)
722+
@test Array(kron(a_d, t(b))::SparseMatrixCSC) == kron(a_d, t(b_d))
723+
@test Array(kron(t(a), c_di)::SparseMatrixCSC) == kron(t(a_d), c_d)
724+
@test Array(kron(a, t(c_di))::SparseMatrixCSC) == kron(a_d, t(c_d))
725+
@test Array(kron(t(a), t(c_di))::SparseMatrixCSC) == kron(t(a_d), t(c_d))
727726
@test issparse(kron(c_di, y))
728727
@test Array(kron(c_di, y)) == kron(c_di, y_d)
729728
@test issparse(kron(x, d_di))

0 commit comments

Comments
 (0)