diff --git a/src/LinearAlgebra.jl b/src/LinearAlgebra.jl index 8b190e39..10b9551f 100644 --- a/src/LinearAlgebra.jl +++ b/src/LinearAlgebra.jl @@ -326,7 +326,7 @@ StridedMatrixStride1{T} = StridedArrayStride1{T,2} """ LinearAlgebra.checksquare(A) -Check that a matrix is square, then return its common dimension. +Checks whether a matrix is square, returning its common dimension if it is the case, or throwing a DimensionMismatch error otherwise. For multiple arguments, return a vector. # Examples @@ -340,19 +340,13 @@ julia> LinearAlgebra.checksquare(A, B) ``` """ function checksquare(A) - m,n = size(A) - m == n || throw(DimensionMismatch(lazy"matrix is not square: dimensions are $(size(A))")) - m + sizeA = size(A) + length(sizeA) == 2 || throw(DimensionMismatch(lazy"input is not a matrix: dimensions are $sizeA")) + sizeA[1] == sizeA[2] || throw(DimensionMismatch(lazy"matrix is not square: dimensions are $sizeA")) + return sizeA[1] end -function checksquare(A...) - sizes = Int[] - for a in A - size(a,1)==size(a,2) || throw(DimensionMismatch(lazy"matrix is not square: dimensions are $(size(a))")) - push!(sizes, size(a,1)) - end - return sizes -end +checksquare(A...) = [checksquare(a) for a in A] function char_uplo(uplo::Symbol) if uplo === :U diff --git a/test/dense.jl b/test/dense.jl index e8c57bf8..97627008 100644 --- a/test/dense.jl +++ b/test/dense.jl @@ -1430,4 +1430,15 @@ end @test log(D) ≈ log(UpperTriangular(D)) end +@testset "issue 1362" begin + A = zeros(2,2) + B = zeros(2,3) + C = zeros(2,2,1) + @test LinearAlgebra.checksquare(A) == 2 + @test LinearAlgebra.checksquare(A,A) == [2, 2] + @test_throws DimensionMismatch LinearAlgebra.checksquare(B) + @test_throws DimensionMismatch LinearAlgebra.checksquare(C) + @test_throws DimensionMismatch LinearAlgebra.checksquare(A,B) +end + end # module TestDense