Skip to content

Commit bfcaad6

Browse files
committed
test dimension
1 parent 7833f14 commit bfcaad6

File tree

2 files changed

+30
-27
lines changed

2 files changed

+30
-27
lines changed

src/linalg.jl

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -630,12 +630,12 @@ function dot(A::AbstractSparseMatrixCSC, B::Union{DenseMatrixUnion,WrapperMatrix
630630
end
631631

632632
function dot(x::AbstractSparseVector, Q::Diagonal, y::AbstractVector)
633-
if length(x) != length(y)
633+
d = Q.diag
634+
if length(x) != length(y) || length(y) != length(d)
634635
throw(
635-
DimensionMismatch("Vector x has a length $(length(x)) but y has a length $(length(y))")
636+
DimensionMismatch("Vectors and matrix have different dimensions, x has a length $(length(x)), y has a length $(length(y)), Q has side dimension $(length(d))")
636637
)
637638
end
638-
d = Q.diag
639639
nzvals = nonzeros(x)
640640
nzinds = nonzeroinds(x)
641641
s = zero(Base.promote_eltype(x, Q, y))
@@ -645,40 +645,40 @@ function dot(x::AbstractSparseVector, Q::Diagonal, y::AbstractVector)
645645
return s
646646
end
647647

648-
function dot(a::AbstractSparseVector, Q::Diagonal, b::AbstractSparseVector)
649-
n = length(a)
650-
if length(b) != n
648+
function dot(x::AbstractSparseVector, Q::Diagonal, y::AbstractSparseVector)
649+
n = length(x)
650+
if length(y) != n || n != size(Q, 1)
651651
throw(
652-
DimensionMismatch("Vector a has a length $n but b has a length $(length(b))")
652+
DimensionMismatch("Vectors and matrix have different dimensions, x has a length $(length(x)), y has a length $(length(y)), Q has side dimension $(length(d))")
653653
)
654654
end
655-
anzind = nonzeroinds(a)
656-
bnzind = nonzeroinds(b)
657-
anzval = nonzeros(a)
658-
bnzval = nonzeros(b)
659-
s = zero(Base.promote_eltype(a, Q, b))
655+
xnzind = nonzeroinds(x)
656+
ynzind = nonzeroinds(y)
657+
xnzval = nonzeros(x)
658+
ynzval = nonzeros(y)
659+
s = zero(Base.promote_eltype(x, Q, y))
660660

661-
if isempty(anzind) || isempty(bnzind)
661+
if isempty(xnzind) || isempty(ynzind)
662662
return s
663663
end
664664

665-
a_idx = 1
666-
b_idx = 1
667-
a_idx_last = length(anzind)
668-
b_idx_last = length(bnzind)
665+
x_idx = 1
666+
y_idx = 1
667+
x_idx_last = length(xnzind)
668+
y_idx_last = length(ynzind)
669669

670670
# go through the nonzero indices of a and b simultaneously
671-
@inbounds while a_idx <= a_idx_last && b_idx <= b_idx_last
672-
ia = anzind[a_idx]
673-
ib = bnzind[b_idx]
674-
if ia == ib
675-
s += dot(anzval[a_idx], Q.diag[ia], bnzval[b_idx])
676-
a_idx += 1
677-
b_idx += 1
678-
elseif ia < ib
679-
a_idx += 1
671+
@inbounds while x_idx <= x_idx_last && y_idx <= y_idx_last
672+
ix = xnzind[x_idx]
673+
iy = ynzind[y_idx]
674+
if ix == iy
675+
s += dot(xnzval[x_idx], Q.diag[ix], ynzval[y_idx])
676+
x_idx += 1
677+
y_idx += 1
678+
elseif ix < iy
679+
x_idx += 1
680680
else
681-
b_idx += 1
681+
y_idx += 1
682682
end
683683
end
684684
return s

test/linalg.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -860,6 +860,9 @@ end
860860
@test dot(b, D, a) dot(b, sparse(D), a)
861861
@test dot(b, D, a) dot(b, D, collect(a))
862862
@test dot(b, D, a) dot(collect(b), D, a)
863+
@test_throws DimensionMismatch dot(b, D, [a; 1])
864+
@test_throws DimensionMismatch dot([b; 1], D, a)
865+
@test_throws DimensionMismatch dot([b; 1], D, [a; 1])
863866
end
864867
end
865868
@test_throws DimensionMismatch dot(sprand(5,5,0.2),sprand(5,6,0.2))

0 commit comments

Comments
 (0)