diff --git a/src/algorithms.jl b/src/algorithms.jl index d4b2bb0a..ea260aa2 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -143,3 +143,39 @@ macro functiondef(f) Core.@__doc__ $f, $f! end) end + +""" + @check_scalar(x, y, [op], [eltype]) + +Check if `eltype(x) == op(eltype(y))` and throw an error if not. +By default `op = identity` and `eltype = eltype'. +""" +macro check_scalar(x, y, op=:identity, eltype=:eltype) + error_message = "Unexpected scalar type: " + error_message *= string(eltype) * "(" * string(x) * ")" + if op == :identity + error_message *= " != " * string(eltype) * "(" * string(y) * ")" + else + error_message *= " != " * string(op) * "(" * string(eltype) * "(" * string(y) * "))" + end + return esc(quote + $eltype($x) == $op($eltype($y)) || throw(ArgumentError($error_message)) + end) +end + +""" + @check_size(x, sz, [size]) + +Check if `size(x) == sz` and throw an error if not. +By default, `size = size`. +""" +macro check_size(x, sz, size=:size) + msgstart = string(size) * "(" * string(x) * ") = " + err = gensym() + return esc(quote + szx = $size($x) + $err = $msgstart * string(szx) * " instead of expected value " * + string($sz) + szx == $sz || throw(DimensionMismatch($err)) + end) +end \ No newline at end of file diff --git a/src/implementations/decompositions.jl b/src/implementations/decompositions.jl index e6880f6b..d886588a 100644 --- a/src/implementations/decompositions.jl +++ b/src/implementations/decompositions.jl @@ -10,19 +10,19 @@ # QR, LQ, QL, RQ Decomposition # ---------------------------- """ - LAPACK_HoudeholderQR(; blocksize, positive = false, pivoted = false) + LAPACK_HouseholderQR(; blocksize, positive = false, pivoted = false) Algorithm type to denote the standard LAPACK algorithm for computing the QR decomposition of a matrix using Householder reflectors. The specific LAPACK function can be controlled using the keyword arugments, i.e. `?geqrt` will be chosen if `blocksize > 1`. With `blocksize == 1`, `?geqrf` will be chosen if `pivoted == false` and `?geqp3` will be chosen -if `pivoted == true`. The keyword `positive =true` can be used to ensure that the diagonal +if `pivoted == true`. The keyword `positive=true` can be used to ensure that the diagonal elements of `R` are non-negative. """ @algdef LAPACK_HouseholderQR """ - LAPACK_HoudeholderLQ(; blocksize, positive = false) + LAPACK_HouseholderLQ(; blocksize, positive = false) Algorithm type to denote the standard LAPACK algorithm for computing the LQ decomposition of a matrix using Householder reflectors. The specific LAPACK function can be controlled using diff --git a/src/implementations/eig.jl b/src/implementations/eig.jl index b3d2f6a0..049438a7 100644 --- a/src/implementations/eig.jl +++ b/src/implementations/eig.jl @@ -10,21 +10,21 @@ copy_input(::typeof(eig_trunc), A) = copy_input(eig_full, A) function check_input(::typeof(eig_full!), A::AbstractMatrix, DV) m, n = size(A) - m == n || throw(ArgumentError("Eigenvalue decomposition requires square input matrix")) + m == n || throw(DimensionMismatch("square input matrix expected")) D, V = DV - Tc = complex(eltype(A)) - (V isa AbstractMatrix && eltype(V) == Tc && size(V) == (m, m)) || - throw(ArgumentError("`eig_full!` requires square matrix V with same size as A and complex `eltype`")) - (D isa Diagonal && eltype(D) == Tc && size(D) == (m, m)) || - throw(ArgumentError("`eig_full!` requires Diagonal matrix D with same size as A and complex `eltype`")) + @assert D isa Diagonal && V isa AbstractMatrix + @check_size(D, (m, m)) + @check_scalar(D, A, complex) + @check_size(V, (m, m)) + @check_scalar(V, A, complex) return nothing end function check_input(::typeof(eig_vals!), A::AbstractMatrix, D) m, n = size(A) - m == n || throw(ArgumentError("Eigenvalue decomposition requires square input matrix")) - Tc = complex(eltype(A)) - size(D) == (n,) && eltype(D) == Tc || - throw(ArgumentError("Eigenvalue vector `D` must have length equal to size(A, 1) and complex `eltype`")) + m == n || throw(DimensionMismatch("square input matrix expected")) + @assert D isa AbstractVector + @check_size(D, (n,)) + @check_scalar(D, A, complex) return nothing end diff --git a/src/implementations/eigh.jl b/src/implementations/eigh.jl index 228e6558..53837ae5 100644 --- a/src/implementations/eigh.jl +++ b/src/implementations/eigh.jl @@ -10,19 +10,20 @@ copy_input(::typeof(eigh_trunc), A) = copy_input(eigh_full, A) function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV) m, n = size(A) - m == n || throw(ArgumentError("Eigenvalue decompsition requires square input matrix")) + m == n || throw(DimensionMismatch("square input matrix expected")) D, V = DV - (V isa AbstractMatrix && eltype(V) == eltype(A) && size(V) == (m, m)) || - throw(ArgumentError("`eigh_full!` requires square V matrix with same size and `eltype` as A")) - (D isa Diagonal && eltype(D) == real(eltype(A)) && size(D) == (m, m)) || - throw(ArgumentError("`eigh_full!` requires Diagonal matrix D with same size as A with a real `eltype`")) + @assert D isa Diagonal && V isa AbstractMatrix + @check_size(D, (m, m)) + @check_scalar(D, A, real) + @check_size(V, (m, m)) + @check_scalar(V, A) return nothing end function check_input(::typeof(eigh_vals!), A::AbstractMatrix, D) m, n = size(A) - m == n || throw(ArgumentError("Eigenvalue decompsition requires square input matrix")) - (size(D) == (n,) && eltype(D) == real(eltype(A))) || - throw(ArgumentError("Eigenvalue vector `D` must have length equal to size(A, 1) with a real `eltype`")) + @assert D isa AbstractVector + @check_size(D, (n,)) + @check_scalar(D, A, real) return nothing end diff --git a/src/implementations/lq.jl b/src/implementations/lq.jl index ea869c87..165c63b5 100644 --- a/src/implementations/lq.jl +++ b/src/implementations/lq.jl @@ -13,30 +13,30 @@ end function check_input(::typeof(lq_full!), A::AbstractMatrix, LQ) m, n = size(A) L, Q = LQ - (Q isa AbstractMatrix && eltype(Q) == eltype(A) && size(Q) == (n, n)) || - throw(DimensionMismatch("Full unitary matrix Q must be square with equal number of columns as A")) - (L isa AbstractMatrix && eltype(L) == eltype(A) && (isempty(L) || size(L) == (m, n))) || - throw(DimensionMismatch("Lower triangular matrix L must have size equal to A")) + @assert L isa AbstractMatrix && Q isa AbstractMatrix + isempty(L) || @check_size(L, (m, n)) + @check_scalar(L, A) + @check_size(Q, (n, n)) + @check_scalar(Q, A) return nothing end function check_input(::typeof(lq_compact!), A::AbstractMatrix, LQ) m, n = size(A) - if m <= n - L, Q = LQ - (Q isa AbstractMatrix && eltype(Q) == eltype(A) && size(Q) == (m, n)) || - throw(DimensionMismatch("Isometric Q must have size equal to A")) - (L isa AbstractMatrix && eltype(L) == eltype(A) && - (isempty(L) || size(L) == (m, m))) || - throw(DimensionMismatch("Lower triangular matrix L must be square with equal number of columns as A")) - else - check_input(lq_full!, A, LQ) - end + minmn = min(m, n) + L, Q = LQ + @assert L isa AbstractMatrix && Q isa AbstractMatrix + isempty(L) || @check_size(L, (m, minmn)) + @check_scalar(L, A) + @check_size(Q, (minmn, n)) + @check_scalar(Q, A) + return nothing end function check_input(::typeof(lq_null!), A::AbstractMatrix, Nᴴ) m, n = size(A) minmn = min(m, n) - (Nᴴ isa AbstractMatrix && eltype(Nᴴ) == eltype(A) && size(Nᴴ) == (n - minmn, n)) || - throw(DimensionMismatch("Matrix Nᴴ must have a the same eltype as A and a size such that [A; Nᴴ] is square")) + @assert Nᴴ isa AbstractMatrix + @check_size(Nᴴ, (n - minmn, n)) + @check_scalar(Nᴴ, A) return nothing end diff --git a/src/implementations/orthnull.jl b/src/implementations/orthnull.jl index 37100610..e134121b 100644 --- a/src/implementations/orthnull.jl +++ b/src/implementations/orthnull.jl @@ -9,37 +9,43 @@ function check_input(::typeof(left_orth!), A::AbstractMatrix, VC) m, n = size(A) minmn = min(m, n) V, C = VC - (V isa AbstractMatrix && eltype(V) == eltype(A) && size(V) == (m, minmn)) || - throw(DimensionMismatch("Isometric V must have the same eltype as A, the same number of rows and min(m, n) columns")) - (C isa AbstractMatrix && eltype(C) == eltype(A) && - (isempty(C) || size(C) == (minmn, n))) || - throw(DimensionMismatch("Corestriction C must have the same eltype as A, the same number of columns and min(m, n) rows")) + @assert V isa AbstractMatrix && C isa AbstractMatrix + @check_size(V, (m, minmn)) + @check_scalar(V, A) + if !isempty(C) + @check_size(C, (minmn, n)) + @check_scalar(C, A) + end return nothing end function check_input(::typeof(right_orth!), A::AbstractMatrix, CVᴴ) m, n = size(A) minmn = min(m, n) C, Vᴴ = CVᴴ - (Vᴴ isa AbstractMatrix && eltype(Vᴴ) == eltype(A) && size(Vᴴ) == (minmn, n)) || - throw(DimensionMismatch("Adjoint isometric matrix Vᴴ must have the same eltype as A, the same number of columns and min(m, n) rows")) - (C isa AbstractMatrix && eltype(C) == eltype(A) && - (isempty(C) || size(C) == (m, minmn))) || - throw(DimensionMismatch("Corestriction C must have the same eltype as A, the same number of rows and min(m, n) columns")) + @assert C isa AbstractMatrix && Vᴴ isa AbstractMatrix + if !isempty(C) + @check_size(C, (m, minmn)) + @check_scalar(C, A) + end + @check_size(Vᴴ, (minmn, n)) + @check_scalar(Vᴴ, A) return nothing end function check_input(::typeof(left_null!), A::AbstractMatrix, N) m, n = size(A) minmn = min(m, n) - (N isa AbstractMatrix && eltype(N) == eltype(A) && size(N) == (m, m - minmn)) || - throw(DimensionMismatch("Isometric matrix must have the same eltype as A, the same number of rows and m - min(m, n) columns")) + @assert N isa AbstractMatrix + @check_size(N, (m, m - minmn)) + @check_scalar(N, A) return nothing end function check_input(::typeof(right_null!), A::AbstractMatrix, Nᴴ) m, n = size(A) minmn = min(m, n) - (Nᴴ isa AbstractMatrix && eltype(Nᴴ) == eltype(A) && size(Nᴴ) == (n - minmn, n)) || - throw(DimensionMismatch("Adjoint isometric matrix Nᴴ must have the same eltype as A, the same number of columns and n - min(m, n) rows")) + @assert Nᴴ isa AbstractMatrix + @check_size(Nᴴ, (n - minmn, n)) + @check_scalar(Nᴴ, A) return nothing end diff --git a/src/implementations/polar.jl b/src/implementations/polar.jl index b4311f08..2604aab5 100644 --- a/src/implementations/polar.jl +++ b/src/implementations/polar.jl @@ -7,22 +7,24 @@ function check_input(::typeof(left_polar!), A::AbstractMatrix, WP) m, n = size(A) W, P = WP m >= n || - throw(ArgumentError("`left_polar!` requires a matrix A with at least as many rows as columns")) - (W isa AbstractMatrix && eltype(W) == eltype(A) && size(W) == (m, n)) || - throw(ArgumentError("`left_polar!` requires a matrix W with the same size and eltype as A")) - (P isa AbstractMatrix && eltype(P) == eltype(A) && size(P) == (n, n)) || - throw(ArgumentError("`left_polar!` requires a square matrix P with the same eltype and number of columns as A")) + throw(ArgumentError("input matrix needs at least as many rows as columns")) + @assert W isa AbstractMatrix && P isa AbstractMatrix + @check_size(W, (m, n)) + @check_scalar(W, A) + @check_size(P, (n, n)) + @check_scalar(P, A) return nothing end function check_input(::typeof(right_polar!), A::AbstractMatrix, PWᴴ) m, n = size(A) P, Wᴴ = PWᴴ n >= m || - throw(ArgumentError("`right_polar!` requires a matrix A with at least as many columns as rows")) - (P isa AbstractMatrix && eltype(P) == eltype(A) && size(P) == (m, m)) || - throw(ArgumentError("`right_polar!` requires a square matrix P with the same eltype and number of rows as A")) - (Wᴴ isa AbstractMatrix && eltype(Wᴴ) == eltype(A) && size(Wᴴ) == (m, n)) || - throw(ArgumentError("`right_polar!` requires a matrix Wᴴ with the same size and eltype as A")) + throw(ArgumentError("input matrix needs at least as many columns as rows")) + @assert P isa AbstractMatrix && Wᴴ isa AbstractMatrix + @check_size(P, (m, m)) + @check_scalar(P, A) + @check_size(Wᴴ, (m, n)) + @check_scalar(Wᴴ, A) return nothing end diff --git a/src/implementations/qr.jl b/src/implementations/qr.jl index f8198755..7e2e13ea 100644 --- a/src/implementations/qr.jl +++ b/src/implementations/qr.jl @@ -13,30 +13,30 @@ end function check_input(::typeof(qr_full!), A::AbstractMatrix, QR) m, n = size(A) Q, R = QR - (Q isa AbstractMatrix && eltype(Q) == eltype(A) && size(Q) == (m, m)) || - throw(DimensionMismatch("Full unitary matrix Q must be square with equal number of rows as A")) - (R isa AbstractMatrix && eltype(R) == eltype(A) && (isempty(R) || size(R) == (m, n))) || - throw(DimensionMismatch("Upper triangular matrix R must have size equal to A")) + @assert Q isa AbstractMatrix && R isa AbstractMatrix + @check_size(Q, (m, m)) + @check_scalar(Q, A) + isempty(R) || @check_size(R, (m, n)) + @check_scalar(R, A) return nothing end function check_input(::typeof(qr_compact!), A::AbstractMatrix, QR) m, n = size(A) - if n <= m - Q, R = QR - (Q isa AbstractMatrix && eltype(Q) == eltype(A) && size(Q) == (m, n)) || - throw(DimensionMismatch("Isometric Q must have size equal to A")) - (R isa AbstractMatrix && eltype(R) == eltype(A) && - (isempty(R) || size(R) == (n, n))) || - throw(DimensionMismatch("Upper triangular matrix R must be square with equal number of columns as A")) - else - check_input(qr_full!, A, QR) - end + minmn = min(m, n) + Q, R = QR + @assert Q isa AbstractMatrix && R isa AbstractMatrix + @check_size(Q, (m, minmn)) + @check_scalar(Q, A) + isempty(R) || @check_size(R, (minmn, n)) + @check_scalar(R, A) + return nothing end function check_input(::typeof(qr_null!), A::AbstractMatrix, N) m, n = size(A) minmn = min(m, n) - (N isa AbstractMatrix && eltype(N) == eltype(A) && size(N) == (m, m - minmn)) || - throw(DimensionMismatch("Matrix N must have a the same eltype as A and a size such that [A N] is square")) + @assert N isa AbstractMatrix + @check_size(N, (m, m - minmn)) + @check_scalar(N, A) return nothing end diff --git a/src/implementations/schur.jl b/src/implementations/schur.jl index 828c16eb..55a1bdfa 100644 --- a/src/implementations/schur.jl +++ b/src/implementations/schur.jl @@ -6,21 +6,23 @@ copy_input(::typeof(schur_vals), A::AbstractMatrix) = copy_input(eig_vals, A) # check input function check_input(::typeof(schur_full!), A::AbstractMatrix, TZv) m, n = size(A) - m == n || throw(ArgumentError("Schur decompsition requires square input matrix")) + m == n || throw(DimensionMismatch("square input matrix expected")) T, Z, vals = TZv - (Z isa AbstractMatrix && eltype(Z) == eltype(A) && size(Z) == (m, m)) || - throw(ArgumentError("`schur_full!` requires square Z matrix with same size and `eltype` as A")) - (T isa AbstractMatrix && eltype(T) == eltype(A) && size(T) == (m, m)) || - throw(ArgumentError("`schur_full!` requires square T matrix with same size and `eltype` as A")) - size(vals) == (n,) && eltype(vals) == complex(eltype(A)) || - throw(ArgumentError("Eigenvalue vector `vals` must have length equal to size(A, 1) and complex `eltype`")) + @assert T isa AbstractMatrix && Z isa AbstractMatrix && vals isa AbstractVector + @check_size(T, (m, m)) + @check_scalar(T, A) + @check_size(Z, (m, m)) + @check_scalar(Z, A) + @check_size(vals, (n,)) + @check_scalar(vals, A, complex) return nothing end function check_input(::typeof(schur_vals!), A::AbstractMatrix, vals) m, n = size(A) - m == n || throw(ArgumentError("Schur decompsition requires square input matrix")) - size(vals) == (n,) && eltype(vals) == complex(eltype(A)) || - throw(ArgumentError("Eigenvalue vector `vals` must have length equal to size(A, 1) and complex `eltype`")) + m == n || throw(DimensionMismatch("square input matrix expected")) + @assert vals isa AbstractVector + @check_size(vals, (n,)) + @check_scalar(vals, A, complex) return nothing end diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index 9224e49b..17c621aa 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -11,31 +11,34 @@ copy_input(::typeof(svd_trunc), A) = copy_input(svd_compact, A) function check_input(::typeof(svd_full!), A::AbstractMatrix, USVᴴ) m, n = size(A) U, S, Vᴴ = USVᴴ - (U isa AbstractMatrix && eltype(U) == eltype(A) && size(U) == (m, m)) || - throw(ArgumentError("`svd_full!` requires square U matrix with equal number of rows and same `eltype` as A")) - (Vᴴ isa AbstractMatrix && eltype(Vᴴ) == eltype(A) && size(Vᴴ) == (n, n)) || - throw(ArgumentError("`svd_full!` requires square Vᴴ matrix with equal number of columns and same `eltype` as A")) - (S isa AbstractMatrix && eltype(S) == real(eltype(A)) && size(S) == (m, n)) || - throw(ArgumentError("`svd_full!` requires a matrix S of the same size as A with a real `eltype`")) + @assert U isa AbstractMatrix && S isa AbstractMatrix && Vᴴ isa AbstractMatrix + @check_size(U, (m, m)) + @check_scalar(U, A) + @check_size(S, (m, n)) + @check_scalar(S, A, real) + @check_size(Vᴴ, (n, n)) + @check_scalar(Vᴴ, A) return nothing end function check_input(::typeof(svd_compact!), A::AbstractMatrix, USVᴴ) m, n = size(A) minmn = min(m, n) U, S, Vᴴ = USVᴴ - (U isa AbstractMatrix && eltype(U) == eltype(A) && size(U) == (m, minmn)) || - throw(ArgumentError("`svd_full!` requires square U matrix with equal number of rows and same `eltype` as A")) - (Vᴴ isa AbstractMatrix && eltype(Vᴴ) == eltype(A) && size(Vᴴ) == (minmn, n)) || - throw(ArgumentError("`svd_full!` requires square Vᴴ matrix with equal number of columns and same `eltype` as A")) - (S isa Diagonal && eltype(S) == real(eltype(A)) && size(S) == (minmn, minmn)) || - throw(ArgumentError("`svd_compact!` requires Diagonal matrix S with number of rows equal to min(size(A)...) with a real `eltype`")) + @assert U isa AbstractMatrix && S isa Diagonal && Vᴴ isa AbstractMatrix + @check_size(U, (m, minmn)) + @check_scalar(U, A) + @check_size(S, (minmn, minmn)) + @check_scalar(S, A, real) + @check_size(Vᴴ, (minmn, n)) + @check_scalar(Vᴴ, A) return nothing end function check_input(::typeof(svd_vals!), A::AbstractMatrix, S) m, n = size(A) minmn = min(m, n) - (S isa AbstractVector && eltype(S) == real(eltype(A)) && size(S) == (minmn,)) || - throw(ArgumentError("`svd_vals!` requires vector S of length min(size(A)...) with a real `eltype`")) + @assert S isa AbstractVector + @check_size(S, (minmn,)) + @check_scalar(S, A, real) return nothing end diff --git a/test/chainrules.jl b/test/chainrules.jl index a31bdc7b..47de1c04 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -325,7 +325,7 @@ end end end -@timedtestset "Orth en null with eltype $T" for T in (Float64, ComplexF64, Float32) +@timedtestset "Orth and null with eltype $T" for T in (Float64, ComplexF64, Float32) rng = StableRNG(12345) m = 19 @testset "size ($m, $n)" for n in (17, m, 23)