Skip to content

Commit 997ce45

Browse files
ranochaViralBShah
authored andcommitted
specialise kron for sparse mat/vec and Diagonal -> sparse matrix (#32793)
1 parent 8272763 commit 997ce45

File tree

2 files changed

+14
-0
lines changed

2 files changed

+14
-0
lines changed

stdlib/SparseArrays/src/linalg.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1307,6 +1307,10 @@ kron(x::SparseVector, A::SparseMatrixCSC) = kron(SparseMatrixCSC(x), A)
13071307
kron(A::Union{SparseVector,SparseMatrixCSC}, B::VecOrMat) = kron(A, sparse(B))
13081308
kron(A::VecOrMat, B::Union{SparseVector,SparseMatrixCSC}) = kron(sparse(A), B)
13091309

1310+
# sparse vec/mat ⊗ Diagonal and vice versa
1311+
kron(A::Diagonal{T}, B::Union{SparseVector{S}, SparseMatrixCSC{S}}) where {T<:Number, S<:Number} = kron(sparse(A), B)
1312+
kron(A::Union{SparseVector{T}, SparseMatrixCSC{T}}, B::Diagonal{S}) where {T<:Number, S<:Number} = kron(A, sparse(B))
1313+
13101314
# sparse outer product
13111315
kron(A::SparseVectorUnion, B::AdjOrTransSparseVectorUnion) = A .* B
13121316

stdlib/SparseArrays/test/sparse.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,10 +374,20 @@ end
374374
v = view(a, :, 1); v_d = Vector(v)
375375
x = sprand(m, 0.4); x_d = Vector(x)
376376
y = sprand(n, 0.3); y_d = Vector(y)
377+
c_di = Diagonal(rand(m)); c = sparse(c_di); c_d = Array(c_di)
378+
d_di = Diagonal(rand(n)); d = sparse(d_di); d_d = Array(d_di)
377379
# mat ⊗ mat
378380
@test Array(kron(a, b)) == kron(a_d, b_d)
379381
@test Array(kron(a_d, b)) == kron(a_d, b_d)
380382
@test Array(kron(a, b_d)) == kron(a_d, b_d)
383+
@test issparse(kron(c, d_di))
384+
@test Array(kron(c, d_di)) == kron(c_d, d_d)
385+
@test issparse(kron(c_di, d))
386+
@test Array(kron(c_di, d)) == kron(c_d, d_d)
387+
@test issparse(kron(c_di, y))
388+
@test Array(kron(c_di, y)) == kron(c_di, y_d)
389+
@test issparse(kron(x, d_di))
390+
@test Array(kron(x, d_di)) == kron(x_d, d_di)
381391
# vec ⊗ vec
382392
@test Vector(kron(x, y)) == kron(x_d, y_d)
383393
@test Vector(kron(x_d, y)) == kron(x_d, y_d)

0 commit comments

Comments
 (0)