Skip to content

Commit 3e51ea3

Browse files
committed
GPU-friendly SVD + correct gaugefix
1 parent f95d1b3 commit 3e51ea3

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

src/implementations/svd.jl

Lines changed: 10 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -248,17 +248,18 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::DiagonalAlgorithm)
248 248
check_input(svd_full!, A, USVᴴ, alg)
249 249
Ad = diagview(A)
250 250
U, S, Vᴴ = USVᴴ
251-
Sd = diagview(S)
252-
Sd .= abs.(Ad)
253-
p = sortperm(Sd; rev=true)
254-
permute!(Sd, p)
255-
T = eltype(Vᴴ)
251+
p = sortperm(Ad; by=abs, rev=true)
256 252
zero!(U)
257 253
zero!(Vᴴ)
258-
@inbounds for (i, pi) in enumerate(p)
259-
s = Ad[pi]
260-
U[pi, i] = sign_safe(s)
261-
Vᴴ[i, pi] = one(T)
254+
T = eltype(U)
255+
U[CartesianIndex.(reverse.(enumerate(p)))] .= Ref(one(T))
256+
Vᴴ[CartesianIndex.(enumerate(p))] .= sign_safe.(view(Ad, p))
257+
Sd = diagview(S)
258+
if Ad === Sd
259+
@. Sd = abs(Ad)
260+
permute!(Sd, p)
261+
else
262+
Sd .= abs.(view(Ad, p))
262 263
end
263 264
return U, S, Vᴴ
264 265
end

0 commit comments

Comments
 (0)