Skip to content

Commit 0822d37

Browse files
committed
Improve type stability
1 parent 289d424 commit 0822d37

File tree

2 files changed

+33
-14
lines changed

2 files changed

+33
-14
lines changed

src/MatrixAlgebra.jl

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -76,13 +76,32 @@ for (svd, svd_trunc, svd_full, svd_compact) in (
7676
(:svd!, :svd_trunc!, :svd_full!, :svd_compact!),
7777
)
7878
@eval begin
79-
function $svd(A::AbstractMatrix; full::Bool = false, trunc = nothing, kwargs...)
80-
return if !isnothing(trunc)
81-
@assert !full "Specified both full and truncation, currently not supported"
82-
$svd_trunc(A; trunc, kwargs...)
83-
else
84-
(full ? $svd_full : $svd_compact)(A; kwargs...)
85-
end
79+
function $svd(
80+
A::AbstractMatrix;
81+
full::Union{Bool, Val{Bool}} = Val(false),
82+
trunc = nothing,
83+
kwargs...,
84+
)
85+
return _svd(full, trunc, A; kwargs...)
86+
end
87+
function _svd(full::Bool, trunc, A::AbstractMatrix; kwargs...)
88+
return _svd(Val(full), trunc, A; kwargs...)
89+
end
90+
function _svd(full::Val{false}, trunc::Nothing, A::AbstractMatrix; kwargs...)
91+
return $svd_compact(A; kwargs...)
92+
end
93+
function _svd(full::Val{false}, trunc, A::AbstractMatrix; kwargs...)
94+
return $svd_trunc(A; trunc, kwargs...)
95+
end
96+
function _svd(full::Val{true}, trunc::Nothing, A::AbstractMatrix; kwargs...)
97+
return $svd_full(A; kwargs...)
98+
end
99+
function _svd(full::Val{true}, trunc, A::AbstractMatrix; kwargs...)
100+
return throw(
101+
ArgumentError(
102+
"Specified both full and truncation, currently not supported"
103+
)
104+
)
86105
end
87106
end
88107
end

test/test_factorizations.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -148,26 +148,26 @@ end
148148
labels_Vᴴ = (:d, :c)
149149

150150
Acopy = deepcopy(A)
151-
U, S, Vᴴ = @constinferred svd(A, labels_A, labels_U, labels_Vᴴ; full = true)
151+
U, S, Vᴴ = @constinferred svd(A, labels_A, labels_U, labels_Vᴴ; full = Val(true))
152152
@test A == Acopy # should not have altered initial array
153153
US, labels_US = contract(U, (labels_U..., :u), S, (:u, :v))
154154
A′ = contract(labels_A, US, labels_US, Vᴴ, (:v, labels_Vᴴ...))
155155
@test A A′
156156
@test size(U, 1) * size(U, 2) == size(U, 3) # U is unitary
157157
@test size(Vᴴ, 1) == size(Vᴴ, 2) * size(Vᴴ, 3) # V is unitary
158158

159-
U, S, Vᴴ = svd(A, (2, 1), (4, 3); full = true)
159+
U, S, Vᴴ = svd(A, (2, 1), (4, 3); full = Val(true))
160160
US, labels_US = contract(U, (labels_U..., :u), S, (:u, :v))
161161
@test A contract(labels_A, US, labels_US, Vᴴ, (:v, labels_Vᴴ...))
162162

163-
U, S, Vᴴ = @constinferred svd(A, labels_A, labels_A, (); full = true)
163+
U, S, Vᴴ = @constinferred svd(A, labels_A, labels_A, (); full = Val(true))
164164
@test A == Acopy # should not have altered initial array
165165
US, labels_US = contract(U, (labels_A..., :u), S, (:u, :v))
166166
A′ = contract(labels_A, US, labels_US, Vᴴ, (:v,))
167167
@test A A′
168168
@test size(Vᴴ, 1) == 1
169169

170-
U, S, Vᴴ = @constinferred svd(A, labels_A, (), labels_A; full = true)
170+
U, S, Vᴴ = @constinferred svd(A, labels_A, (), labels_A; full = Val(true))
171171
@test A == Acopy # should not have altered initial array
172172
US, labels_US = contract(U, (:u,), S, (:u, :v))
173173
A′ = contract(labels_A, US, labels_US, Vᴴ, (:v, labels_A...))
@@ -182,7 +182,7 @@ end
182182
labels_Vᴴ = (:d, :c)
183183

184184
Acopy = deepcopy(A)
185-
U, S, Vᴴ = @constinferred svd(A, labels_A, labels_U, labels_Vᴴ; full = false)
185+
U, S, Vᴴ = @constinferred svd(A, labels_A, labels_U, labels_Vᴴ; full = Val(false))
186186
@test A == Acopy # should not have altered initial array
187187
US, labels_US = contract(U, (labels_U..., :u), S, (:u, :v))
188188
A′ = contract(labels_A, US, labels_US, Vᴴ, (:v, labels_Vᴴ...))
@@ -193,14 +193,14 @@ end
193193
Svals = @constinferred svdvals(A, labels_A, labels_U, labels_Vᴴ)
194194
@test Svals diag(S)
195195

196-
U, S, Vᴴ = @constinferred svd(A, labels_A, labels_A, (); full = false)
196+
U, S, Vᴴ = @constinferred svd(A, labels_A, labels_A, (); full = Val(false))
197197
@test A == Acopy # should not have altered initial array
198198
US, labels_US = contract(U, (labels_A..., :u), S, (:u, :v))
199199
A′ = contract(labels_A, US, labels_US, Vᴴ, (:v,))
200200
@test A A′
201201
@test size(U, ndims(U)) == 1 == size(Vᴴ, 1)
202202

203-
U, S, Vᴴ = @constinferred svd(A, labels_A, (), labels_A; full = false)
203+
U, S, Vᴴ = @constinferred svd(A, labels_A, (), labels_A; full = Val(false))
204204
@test A == Acopy # should not have altered initial array
205205
US, labels_US = contract(U, (:u,), S, (:u, :v))
206206
A′ = contract(labels_A, US, labels_US, Vᴴ, (:v, labels_A...))

0 commit comments

Comments
 (0)