Skip to content

Commit b17cd17

Browse files
committed
Make dot consistently return zero on empty arrays (#1494)
1 parent a6ae676 commit b17cd17

File tree

4 files changed

+20
-14
lines changed

4 files changed

+20
-14
lines changed

src/generic.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -986,7 +986,8 @@ function dot(x::AbstractArray, y::AbstractArray)
986986
throw(DimensionMismatch(lazy"first array has length $(lx) which does not match the length of the second, $(length(y))."))
987987
end
988988
if lx == 0
989-
return dot(zero(eltype(x)), zero(eltype(y)))
989+
# make sure the returned result equals exactly the zero element
990+
return zero(dot(zero(eltype(x)), zero(eltype(y))))
990991
end
991992
s = zero(dot(first(x), first(y)))
992993
for (Ix, Iy) in zip(eachindex(x), eachindex(y))

src/symmetric.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -520,7 +520,7 @@ for (T, trans, real) in [(:Symmetric, :transpose, :identity), (:(Hermitian{<:Uni
520520
if n != size(B, 2)
521521
throw(DimensionMismatch(lazy"A has dimensions $(size(A)) but B has dimensions $(size(B))"))
522522
end
523-
523+
iszero(n) && return $real(zero(dot(zero(eltype(A)), zero(eltype(B)))))
524524
dotprod = $real(zero(dot(first(A), first(B))))
525525
@inbounds if A.uplo == 'U' && B.uplo == 'U'
526526
for j in 1:n

test/matmul.jl

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -676,23 +676,24 @@ end
676676
@test dot(Z, Z) == convert(elty, 34.0)
677677
end
678678

679-
dot1(x, y) = invoke(dot, Tuple{Any,Any}, x, y)
680-
dot2(x, y) = invoke(dot, Tuple{AbstractArray,AbstractArray}, x, y)
681679
@testset "generic dot" begin
680+
dot1(x, y) = invoke(dot, Tuple{Any,Any}, x, y)
681+
dot2(x, y) = invoke(dot, Tuple{AbstractArray,AbstractArray}, x, y)
682682
AA = [1+2im 3+4im; 5+6im 7+8im]
683683
BB = [2+7im 4+1im; 3+8im 6+5im]
684684
for A in (copy(AA), view(AA, 1:2, 1:2)), B in (copy(BB), view(BB, 1:2, 1:2))
685685
@test dot(A, B) == dot(vec(A), vec(B)) == dot1(A, B) == dot2(A, B) == dot(float.(A), float.(B))
686-
@test dot(Int[], Int[]) == 0 == dot1(Int[], Int[]) == dot2(Int[], Int[])
687-
@test_throws MethodError dot(Any[], Any[])
688-
@test_throws MethodError dot1(Any[], Any[])
689-
@test_throws MethodError dot2(Any[], Any[])
690-
for n1 = 0:2, n2 = 0:2, d in (dot, dot1, dot2)
691-
if n1 != n2
692-
@test_throws DimensionMismatch d(1:n1, 1:n2)
693-
else
694-
@test d(1:n1, 1:n2) norm(1:n1)^2
695-
end
686+
end
687+
@test dot(Int[], Int[]) == 0 == dot1(Int[], Int[]) == dot2(Int[], Int[])
688+
@test dot(ComplexF64[], Float64[]) === dot(ComplexF64[;;], Float64[;;]) === zero(ComplexF64)
689+
@test_throws MethodError dot(Any[], Any[])
690+
@test_throws MethodError dot1(Any[], Any[])
691+
@test_throws MethodError dot2(Any[], Any[])
692+
for n1 = 0:2, n2 = 0:2, d in (dot, dot1, dot2)
693+
if n1 != n2
694+
@test_throws DimensionMismatch d(1:n1, 1:n2)
695+
else
696+
@test d(1:n1, 1:n2) norm(1:n1)^2
696697
end
697698
end
698699
end

test/symmetric.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,10 @@ end
470470
@test dot(symblockmu, symblockml) dot(msymblockmu, msymblockml)
471471
@test dot(symblockml, symblockmu) dot(msymblockml, msymblockmu)
472472
@test dot(symblockml, symblockml) dot(msymblockml, msymblockml)
473+
474+
# empty matrices
475+
@test dot(mtype(ComplexF64[;;], :U), mtype(Float64[;;], :U)) === zero(mtype == Hermitian ? Float64 : ComplexF64)
476+
@test dot(mtype(ComplexF64[;;], :L), mtype(Float64[;;], :L)) === zero(mtype == Hermitian ? Float64 : ComplexF64)
473477
end
474478
end
475479

0 commit comments

Comments
 (0)