Skip to content

Commit 84c9298

Browse files
authored
Add bounds checks to inplace kron! (#330)
1 parent c7ad0b9 commit 84c9298

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

src/linalg.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1318,7 +1318,8 @@ end
13181318
@inline function kron!(C::SparseMatrixCSC, A::AbstractSparseMatrixCSC, B::AbstractSparseMatrixCSC)
13191319
mA, nA = size(A); mB, nB = size(B)
13201320
mC, nC = mA*mB, nA*nB
1321-
1321+
@boundscheck size(C) == (mC, nC) || throw(DimensionMismatch("target matrix needs to have size ($mC, $nC)," *
1322+
" but has size $(size(C))"))
13221323
rowvalC = rowvals(C)
13231324
nzvalC = nonzeros(C)
13241325
colptrC = getcolptr(C)
@@ -1362,6 +1363,8 @@ end
13621363
return kron!(C, copy(A), copy(B))
13631364
end
13641365
@inline function kron!(z::SparseVector, x::SparseVector, y::SparseVector)
1366+
@boundscheck length(z) == length(x)*length(y) || throw(DimensionMismatch("length of " *
1367+
"target vector needs to be $(length(x)*length(y)), but has length $(length(z))"))
13651368
nnzx = nnz(x); nnzy = nnz(y);
13661369
nzind = nonzeroinds(z)
13671370
nzval = nonzeros(z)

test/linalg.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -754,6 +754,10 @@ end
754754
@test Vector(kron(x, z)) == kron(x_d, z_d)
755755
@test Array(kron(a, z)) == kron(a_d, z_d)
756756
@test Array(kron(z, b)) == kron(z_d, b_d)
757+
# test bounds checks
758+
@test_throws DimensionMismatch kron!(copy(a), a, b)
759+
@test_throws DimensionMismatch kron!(copy(x), x, y)
760+
@test_throws DimensionMismatch kron!(spzeros(2,2), x, y')
757761
end
758762
end
759763

0 commit comments

Comments
 (0)