Skip to content

Commit 6e119a5

Browse files
committed
GPU-friendly SVD + correct gaugefix
1 parent f95d1b3 commit 6e119a5

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

src/implementations/svd.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -248,17 +248,17 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::DiagonalAlgorithm)
248248
check_input(svd_full!, A, USVᴴ, alg)
249249
Ad = diagview(A)
250250
U, S, Vᴴ = USVᴴ
251-
Sd = diagview(S)
252-
Sd .= abs.(Ad)
253-
p = sortperm(Sd; rev=true)
254-
permute!(Sd, p)
255-
T = eltype(Vᴴ)
251+
p = sortperm(Ad; by=abs, rev=true)
256252
zero!(U)
257253
zero!(Vᴴ)
258-
@inbounds for (i, pi) in enumerate(p)
259-
s = Ad[pi]
260-
U[pi, i] = sign_safe(s)
261-
Vᴴ[i, pi] = one(T)
254+
U[CartesianIndex.(enumerate(p))] .= Ref(one(T))
255+
Vᴴ[CartesianIndex.(reverse.(enumerate(p)))] .= sign_safe.(view(Ad, p))
256+
Sd = diagview(S)
257+
if Ad === Sd
258+
@. Sd = abs(Ad)
259+
permute!(Sd, p)
260+
else
261+
Sd .= abs.(view(Ad, p))
262262
end
263263
return U, S, Vᴴ
264264
end

0 commit comments

Comments
 (0)