Skip to content

Commit 47cef6e

Browse files
Jutholkdvos
andauthored
Jh/checkmacros (#13)
* add check macros * fix typo * another typo * Apply suggestions from code review Co-authored-by: Lukas Devos <[email protected]> --------- Co-authored-by: Lukas Devos <[email protected]>
1 parent 13e6245 commit 47cef6e

File tree

11 files changed

+152
-102
lines changed

11 files changed

+152
-102
lines changed

src/algorithms.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,3 +143,39 @@ macro functiondef(f)
143143
Core.@__doc__ $f, $f!
144144
end)
145145
end
146+
147+
"""
148+
@check_scalar(x, y, [op], [eltype])
149+
150+
Check if `eltype(x) == op(eltype(y))` and throw an error if not.
151+
By default `op = identity` and `eltype = eltype'.
152+
"""
153+
macro check_scalar(x, y, op=:identity, eltype=:eltype)
154+
error_message = "Unexpected scalar type: "
155+
error_message *= string(eltype) * "(" * string(x) * ")"
156+
if op == :identity
157+
error_message *= " != " * string(eltype) * "(" * string(y) * ")"
158+
else
159+
error_message *= " != " * string(op) * "(" * string(eltype) * "(" * string(y) * "))"
160+
end
161+
return esc(quote
162+
$eltype($x) == $op($eltype($y)) || throw(ArgumentError($error_message))
163+
end)
164+
end
165+
166+
"""
167+
@check_size(x, sz, [size])
168+
169+
Check if `size(x) == sz` and throw an error if not.
170+
By default, `size = size`.
171+
"""
172+
macro check_size(x, sz, size=:size)
173+
msgstart = string(size) * "(" * string(x) * ") = "
174+
err = gensym()
175+
return esc(quote
176+
szx = $size($x)
177+
$err = $msgstart * string(szx) * " instead of expected value " *
178+
string($sz)
179+
szx == $sz || throw(DimensionMismatch($err))
180+
end)
181+
end

src/implementations/decompositions.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,19 @@
1010
# QR, LQ, QL, RQ Decomposition
1111
# ----------------------------
1212
"""
13-
LAPACK_HoudeholderQR(; blocksize, positive = false, pivoted = false)
13+
LAPACK_HouseholderQR(; blocksize, positive = false, pivoted = false)
1414
1515
Algorithm type to denote the standard LAPACK algorithm for computing the QR decomposition of
1616
a matrix using Householder reflectors. The specific LAPACK function can be controlled using
1717
the keyword arugments, i.e. `?geqrt` will be chosen if `blocksize > 1`. With
1818
`blocksize == 1`, `?geqrf` will be chosen if `pivoted == false` and `?geqp3` will be chosen
19-
if `pivoted == true`. The keyword `positive =true` can be used to ensure that the diagonal
19+
if `pivoted == true`. The keyword `positive=true` can be used to ensure that the diagonal
2020
elements of `R` are non-negative.
2121
"""
2222
@algdef LAPACK_HouseholderQR
2323

2424
"""
25-
LAPACK_HoudeholderLQ(; blocksize, positive = false)
25+
LAPACK_HouseholderLQ(; blocksize, positive = false)
2626
2727
Algorithm type to denote the standard LAPACK algorithm for computing the LQ decomposition of
2828
a matrix using Householder reflectors. The specific LAPACK function can be controlled using

src/implementations/eig.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,21 @@ copy_input(::typeof(eig_trunc), A) = copy_input(eig_full, A)
1010

1111
function check_input(::typeof(eig_full!), A::AbstractMatrix, DV)
1212
m, n = size(A)
13-
m == n || throw(ArgumentError("Eigenvalue decomposition requires square input matrix"))
13+
m == n || throw(DimensionMismatch("square input matrix expected"))
1414
D, V = DV
15-
Tc = complex(eltype(A))
16-
(V isa AbstractMatrix && eltype(V) == Tc && size(V) == (m, m)) ||
17-
throw(ArgumentError("`eig_full!` requires square matrix V with same size as A and complex `eltype`"))
18-
(D isa Diagonal && eltype(D) == Tc && size(D) == (m, m)) ||
19-
throw(ArgumentError("`eig_full!` requires Diagonal matrix D with same size as A and complex `eltype`"))
15+
@assert D isa Diagonal && V isa AbstractMatrix
16+
@check_size(D, (m, m))
17+
@check_scalar(D, A, complex)
18+
@check_size(V, (m, m))
19+
@check_scalar(V, A, complex)
2020
return nothing
2121
end
2222
function check_input(::typeof(eig_vals!), A::AbstractMatrix, D)
2323
m, n = size(A)
24-
m == n || throw(ArgumentError("Eigenvalue decomposition requires square input matrix"))
25-
Tc = complex(eltype(A))
26-
size(D) == (n,) && eltype(D) == Tc ||
27-
throw(ArgumentError("Eigenvalue vector `D` must have length equal to size(A, 1) and complex `eltype`"))
24+
m == n || throw(DimensionMismatch("square input matrix expected"))
25+
@assert D isa AbstractVector
26+
@check_size(D, (n,))
27+
@check_scalar(D, A, complex)
2828
return nothing
2929
end
3030

src/implementations/eigh.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,20 @@ copy_input(::typeof(eigh_trunc), A) = copy_input(eigh_full, A)
1010

1111
function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV)
1212
m, n = size(A)
13-
m == n || throw(ArgumentError("Eigenvalue decompsition requires square input matrix"))
13+
m == n || throw(DimensionMismatch("square input matrix expected"))
1414
D, V = DV
15-
(V isa AbstractMatrix && eltype(V) == eltype(A) && size(V) == (m, m)) ||
16-
throw(ArgumentError("`eigh_full!` requires square V matrix with same size and `eltype` as A"))
17-
(D isa Diagonal && eltype(D) == real(eltype(A)) && size(D) == (m, m)) ||
18-
throw(ArgumentError("`eigh_full!` requires Diagonal matrix D with same size as A with a real `eltype`"))
15+
@assert D isa Diagonal && V isa AbstractMatrix
16+
@check_size(D, (m, m))
17+
@check_scalar(D, A, real)
18+
@check_size(V, (m, m))
19+
@check_scalar(V, A)
1920
return nothing
2021
end
2122
function check_input(::typeof(eigh_vals!), A::AbstractMatrix, D)
2223
m, n = size(A)
23-
m == n || throw(ArgumentError("Eigenvalue decompsition requires square input matrix"))
24-
(size(D) == (n,) && eltype(D) == real(eltype(A))) ||
25-
throw(ArgumentError("Eigenvalue vector `D` must have length equal to size(A, 1) with a real `eltype`"))
24+
@assert D isa AbstractVector
25+
@check_size(D, (n,))
26+
@check_scalar(D, A, real)
2627
return nothing
2728
end
2829

src/implementations/lq.jl

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,30 +13,30 @@ end
1313
function check_input(::typeof(lq_full!), A::AbstractMatrix, LQ)
1414
m, n = size(A)
1515
L, Q = LQ
16-
(Q isa AbstractMatrix && eltype(Q) == eltype(A) && size(Q) == (n, n)) ||
17-
throw(DimensionMismatch("Full unitary matrix Q must be square with equal number of columns as A"))
18-
(L isa AbstractMatrix && eltype(L) == eltype(A) && (isempty(L) || size(L) == (m, n))) ||
19-
throw(DimensionMismatch("Lower triangular matrix L must have size equal to A"))
16+
@assert L isa AbstractMatrix && Q isa AbstractMatrix
17+
isempty(L) || @check_size(L, (m, n))
18+
@check_scalar(L, A)
19+
@check_size(Q, (n, n))
20+
@check_scalar(Q, A)
2021
return nothing
2122
end
2223
function check_input(::typeof(lq_compact!), A::AbstractMatrix, LQ)
2324
m, n = size(A)
24-
if m <= n
25-
L, Q = LQ
26-
(Q isa AbstractMatrix && eltype(Q) == eltype(A) && size(Q) == (m, n)) ||
27-
throw(DimensionMismatch("Isometric Q must have size equal to A"))
28-
(L isa AbstractMatrix && eltype(L) == eltype(A) &&
29-
(isempty(L) || size(L) == (m, m))) ||
30-
throw(DimensionMismatch("Lower triangular matrix L must be square with equal number of columns as A"))
31-
else
32-
check_input(lq_full!, A, LQ)
33-
end
25+
minmn = min(m, n)
26+
L, Q = LQ
27+
@assert L isa AbstractMatrix && Q isa AbstractMatrix
28+
isempty(L) || @check_size(L, (m, minmn))
29+
@check_scalar(L, A)
30+
@check_size(Q, (minmn, n))
31+
@check_scalar(Q, A)
32+
return nothing
3433
end
3534
function check_input(::typeof(lq_null!), A::AbstractMatrix, Nᴴ)
3635
m, n = size(A)
3736
minmn = min(m, n)
38-
(Nᴴ isa AbstractMatrix && eltype(Nᴴ) == eltype(A) && size(Nᴴ) == (n - minmn, n)) ||
39-
throw(DimensionMismatch("Matrix Nᴴ must have a the same eltype as A and a size such that [A; Nᴴ] is square"))
37+
@assert Nᴴ isa AbstractMatrix
38+
@check_size(Nᴴ, (n - minmn, n))
39+
@check_scalar(Nᴴ, A)
4040
return nothing
4141
end
4242

src/implementations/orthnull.jl

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,37 +9,43 @@ function check_input(::typeof(left_orth!), A::AbstractMatrix, VC)
99
m, n = size(A)
1010
minmn = min(m, n)
1111
V, C = VC
12-
(V isa AbstractMatrix && eltype(V) == eltype(A) && size(V) == (m, minmn)) ||
13-
throw(DimensionMismatch("Isometric V must have the same eltype as A, the same number of rows and min(m, n) columns"))
14-
(C isa AbstractMatrix && eltype(C) == eltype(A) &&
15-
(isempty(C) || size(C) == (minmn, n))) ||
16-
throw(DimensionMismatch("Corestriction C must have the same eltype as A, the same number of columns and min(m, n) rows"))
12+
@assert V isa AbstractMatrix && C isa AbstractMatrix
13+
@check_size(V, (m, minmn))
14+
@check_scalar(V, A)
15+
if !isempty(C)
16+
@check_size(C, (minmn, n))
17+
@check_scalar(C, A)
18+
end
1719
return nothing
1820
end
1921
function check_input(::typeof(right_orth!), A::AbstractMatrix, CVᴴ)
2022
m, n = size(A)
2123
minmn = min(m, n)
2224
C, Vᴴ = CVᴴ
23-
(Vᴴ isa AbstractMatrix && eltype(Vᴴ) == eltype(A) && size(Vᴴ) == (minmn, n)) ||
24-
throw(DimensionMismatch("Adjoint isometric matrix Vᴴ must have the same eltype as A, the same number of columns and min(m, n) rows"))
25-
(C isa AbstractMatrix && eltype(C) == eltype(A) &&
26-
(isempty(C) || size(C) == (m, minmn))) ||
27-
throw(DimensionMismatch("Corestriction C must have the same eltype as A, the same number of rows and min(m, n) columns"))
25+
@assert C isa AbstractMatrix && Vᴴ isa AbstractMatrix
26+
if !isempty(C)
27+
@check_size(C, (m, minmn))
28+
@check_scalar(C, A)
29+
end
30+
@check_size(Vᴴ, (minmn, n))
31+
@check_scalar(Vᴴ, A)
2832
return nothing
2933
end
3034

3135
function check_input(::typeof(left_null!), A::AbstractMatrix, N)
3236
m, n = size(A)
3337
minmn = min(m, n)
34-
(N isa AbstractMatrix && eltype(N) == eltype(A) && size(N) == (m, m - minmn)) ||
35-
throw(DimensionMismatch("Isometric matrix must have the same eltype as A, the same number of rows and m - min(m, n) columns"))
38+
@assert N isa AbstractMatrix
39+
@check_size(N, (m, m - minmn))
40+
@check_scalar(N, A)
3641
return nothing
3742
end
3843
function check_input(::typeof(right_null!), A::AbstractMatrix, Nᴴ)
3944
m, n = size(A)
4045
minmn = min(m, n)
41-
(Nᴴ isa AbstractMatrix && eltype(Nᴴ) == eltype(A) && size(Nᴴ) == (n - minmn, n)) ||
42-
throw(DimensionMismatch("Adjoint isometric matrix Nᴴ must have the same eltype as A, the same number of columns and n - min(m, n) rows"))
46+
@assert Nᴴ isa AbstractMatrix
47+
@check_size(Nᴴ, (n - minmn, n))
48+
@check_scalar(Nᴴ, A)
4349
return nothing
4450
end
4551

src/implementations/polar.jl

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,24 @@ function check_input(::typeof(left_polar!), A::AbstractMatrix, WP)
77
m, n = size(A)
88
W, P = WP
99
m >= n ||
10-
throw(ArgumentError("`left_polar!` requires a matrix A with at least as many rows as columns"))
11-
(W isa AbstractMatrix && eltype(W) == eltype(A) && size(W) == (m, n)) ||
12-
throw(ArgumentError("`left_polar!` requires a matrix W with the same size and eltype as A"))
13-
(P isa AbstractMatrix && eltype(P) == eltype(A) && size(P) == (n, n)) ||
14-
throw(ArgumentError("`left_polar!` requires a square matrix P with the same eltype and number of columns as A"))
10+
throw(ArgumentError("input matrix needs at least as many rows as columns"))
11+
@assert W isa AbstractMatrix && P isa AbstractMatrix
12+
@check_size(W, (m, n))
13+
@check_scalar(W, A)
14+
@check_size(P, (n, n))
15+
@check_scalar(P, A)
1516
return nothing
1617
end
1718
function check_input(::typeof(right_polar!), A::AbstractMatrix, PWᴴ)
1819
m, n = size(A)
1920
P, Wᴴ = PWᴴ
2021
n >= m ||
21-
throw(ArgumentError("`right_polar!` requires a matrix A with at least as many columns as rows"))
22-
(P isa AbstractMatrix && eltype(P) == eltype(A) && size(P) == (m, m)) ||
23-
throw(ArgumentError("`right_polar!` requires a square matrix P with the same eltype and number of rows as A"))
24-
(Wᴴ isa AbstractMatrix && eltype(Wᴴ) == eltype(A) && size(Wᴴ) == (m, n)) ||
25-
throw(ArgumentError("`right_polar!` requires a matrix Wᴴ with the same size and eltype as A"))
22+
throw(ArgumentError("input matrix needs at least as many columns as rows"))
23+
@assert P isa AbstractMatrix && Wᴴ isa AbstractMatrix
24+
@check_size(P, (m, m))
25+
@check_scalar(P, A)
26+
@check_size(Wᴴ, (m, n))
27+
@check_scalar(Wᴴ, A)
2628
return nothing
2729
end
2830

src/implementations/qr.jl

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,30 +13,30 @@ end
1313
function check_input(::typeof(qr_full!), A::AbstractMatrix, QR)
1414
m, n = size(A)
1515
Q, R = QR
16-
(Q isa AbstractMatrix && eltype(Q) == eltype(A) && size(Q) == (m, m)) ||
17-
throw(DimensionMismatch("Full unitary matrix Q must be square with equal number of rows as A"))
18-
(R isa AbstractMatrix && eltype(R) == eltype(A) && (isempty(R) || size(R) == (m, n))) ||
19-
throw(DimensionMismatch("Upper triangular matrix R must have size equal to A"))
16+
@assert Q isa AbstractMatrix && R isa AbstractMatrix
17+
@check_size(Q, (m, m))
18+
@check_scalar(Q, A)
19+
isempty(R) || @check_size(R, (m, n))
20+
@check_scalar(R, A)
2021
return nothing
2122
end
2223
function check_input(::typeof(qr_compact!), A::AbstractMatrix, QR)
2324
m, n = size(A)
24-
if n <= m
25-
Q, R = QR
26-
(Q isa AbstractMatrix && eltype(Q) == eltype(A) && size(Q) == (m, n)) ||
27-
throw(DimensionMismatch("Isometric Q must have size equal to A"))
28-
(R isa AbstractMatrix && eltype(R) == eltype(A) &&
29-
(isempty(R) || size(R) == (n, n))) ||
30-
throw(DimensionMismatch("Upper triangular matrix R must be square with equal number of columns as A"))
31-
else
32-
check_input(qr_full!, A, QR)
33-
end
25+
minmn = min(m, n)
26+
Q, R = QR
27+
@assert Q isa AbstractMatrix && R isa AbstractMatrix
28+
@check_size(Q, (m, minmn))
29+
@check_scalar(Q, A)
30+
isempty(R) || @check_size(R, (minmn, n))
31+
@check_scalar(R, A)
32+
return nothing
3433
end
3534
function check_input(::typeof(qr_null!), A::AbstractMatrix, N)
3635
m, n = size(A)
3736
minmn = min(m, n)
38-
(N isa AbstractMatrix && eltype(N) == eltype(A) && size(N) == (m, m - minmn)) ||
39-
throw(DimensionMismatch("Matrix N must have a the same eltype as A and a size such that [A N] is square"))
37+
@assert N isa AbstractMatrix
38+
@check_size(N, (m, m - minmn))
39+
@check_scalar(N, A)
4040
return nothing
4141
end
4242

src/implementations/schur.jl

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,23 @@ copy_input(::typeof(schur_vals), A::AbstractMatrix) = copy_input(eig_vals, A)
66
# check input
77
function check_input(::typeof(schur_full!), A::AbstractMatrix, TZv)
88
m, n = size(A)
9-
m == n || throw(ArgumentError("Schur decompsition requires square input matrix"))
9+
m == n || throw(DimensionMismatch("square input matrix expected"))
1010
T, Z, vals = TZv
11-
(Z isa AbstractMatrix && eltype(Z) == eltype(A) && size(Z) == (m, m)) ||
12-
throw(ArgumentError("`schur_full!` requires square Z matrix with same size and `eltype` as A"))
13-
(T isa AbstractMatrix && eltype(T) == eltype(A) && size(T) == (m, m)) ||
14-
throw(ArgumentError("`schur_full!` requires square T matrix with same size and `eltype` as A"))
15-
size(vals) == (n,) && eltype(vals) == complex(eltype(A)) ||
16-
throw(ArgumentError("Eigenvalue vector `vals` must have length equal to size(A, 1) and complex `eltype`"))
11+
@assert T isa AbstractMatrix && Z isa AbstractMatrix && vals isa AbstractVector
12+
@check_size(T, (m, m))
13+
@check_scalar(T, A)
14+
@check_size(Z, (m, m))
15+
@check_scalar(Z, A)
16+
@check_size(vals, (n,))
17+
@check_scalar(vals, A, complex)
1718
return nothing
1819
end
1920
function check_input(::typeof(schur_vals!), A::AbstractMatrix, vals)
2021
m, n = size(A)
21-
m == n || throw(ArgumentError("Schur decompsition requires square input matrix"))
22-
size(vals) == (n,) && eltype(vals) == complex(eltype(A)) ||
23-
throw(ArgumentError("Eigenvalue vector `vals` must have length equal to size(A, 1) and complex `eltype`"))
22+
m == n || throw(DimensionMismatch("square input matrix expected"))
23+
@assert vals isa AbstractVector
24+
@check_size(vals, (n,))
25+
@check_scalar(vals, A, complex)
2426
return nothing
2527
end
2628

src/implementations/svd.jl

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,31 +11,34 @@ copy_input(::typeof(svd_trunc), A) = copy_input(svd_compact, A)
1111
function check_input(::typeof(svd_full!), A::AbstractMatrix, USVᴴ)
1212
m, n = size(A)
1313
U, S, Vᴴ = USVᴴ
14-
(U isa AbstractMatrix && eltype(U) == eltype(A) && size(U) == (m, m)) ||
15-
throw(ArgumentError("`svd_full!` requires square U matrix with equal number of rows and same `eltype` as A"))
16-
(Vᴴ isa AbstractMatrix && eltype(Vᴴ) == eltype(A) && size(Vᴴ) == (n, n)) ||
17-
throw(ArgumentError("`svd_full!` requires square Vᴴ matrix with equal number of columns and same `eltype` as A"))
18-
(S isa AbstractMatrix && eltype(S) == real(eltype(A)) && size(S) == (m, n)) ||
19-
throw(ArgumentError("`svd_full!` requires a matrix S of the same size as A with a real `eltype`"))
14+
@assert U isa AbstractMatrix && S isa AbstractMatrix && Vᴴ isa AbstractMatrix
15+
@check_size(U, (m, m))
16+
@check_scalar(U, A)
17+
@check_size(S, (m, n))
18+
@check_scalar(S, A, real)
19+
@check_size(Vᴴ, (n, n))
20+
@check_scalar(Vᴴ, A)
2021
return nothing
2122
end
2223
function check_input(::typeof(svd_compact!), A::AbstractMatrix, USVᴴ)
2324
m, n = size(A)
2425
minmn = min(m, n)
2526
U, S, Vᴴ = USVᴴ
26-
(U isa AbstractMatrix && eltype(U) == eltype(A) && size(U) == (m, minmn)) ||
27-
throw(ArgumentError("`svd_full!` requires square U matrix with equal number of rows and same `eltype` as A"))
28-
(Vᴴ isa AbstractMatrix && eltype(Vᴴ) == eltype(A) && size(Vᴴ) == (minmn, n)) ||
29-
throw(ArgumentError("`svd_full!` requires square Vᴴ matrix with equal number of columns and same `eltype` as A"))
30-
(S isa Diagonal && eltype(S) == real(eltype(A)) && size(S) == (minmn, minmn)) ||
31-
throw(ArgumentError("`svd_compact!` requires Diagonal matrix S with number of rows equal to min(size(A)...) with a real `eltype`"))
27+
@assert U isa AbstractMatrix && S isa Diagonal && Vᴴ isa AbstractMatrix
28+
@check_size(U, (m, minmn))
29+
@check_scalar(U, A)
30+
@check_size(S, (minmn, minmn))
31+
@check_scalar(S, A, real)
32+
@check_size(Vᴴ, (minmn, n))
33+
@check_scalar(Vᴴ, A)
3234
return nothing
3335
end
3436
function check_input(::typeof(svd_vals!), A::AbstractMatrix, S)
3537
m, n = size(A)
3638
minmn = min(m, n)
37-
(S isa AbstractVector && eltype(S) == real(eltype(A)) && size(S) == (minmn,)) ||
38-
throw(ArgumentError("`svd_vals!` requires vector S of length min(size(A)...) with a real `eltype`"))
39+
@assert S isa AbstractVector
40+
@check_size(S, (minmn,))
41+
@check_scalar(S, A, real)
3942
return nothing
4043
end
4144

0 commit comments

Comments
 (0)