Skip to content

Commit f91f42a

Browse files
committed
Make left polar newton more GPU friendly
1 parent 9825144 commit f91f42a

File tree

2 files changed

+10
-8
lines changed

2 files changed

+10
-8
lines changed

src/implementations/polar.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,19 +107,19 @@ function _left_polarnewton!(A::AbstractMatrix, W, P = similar(A, (0, 0)); tol =
107107
if m > n # initial QR
108108
Q, R = qr_compact!(A)
109109
Rc = view(A, 1:n, 1:n)
110-
copy!(Rc, R)
110+
Rc .= R
111111
Rᴴinv = ldiv!(UpperTriangular(Rc)', one!(Rᴴinv))
112112
else # m == n
113113
R = A
114114
Rc = view(W, 1:n, 1:n)
115-
copy!(Rc, R)
115+
Rc .= R
116116
Rᴴinv = ldiv!(lu!(Rc)', one!(Rᴴinv))
117117
end
118118
γ = sqrt(norm(Rᴴinv) / norm(R)) # scaling factor
119119
rmul!(R, γ)
120120
rmul!(Rᴴinv, 1 / γ)
121121
R, Rᴴinv = _avgdiff!(R, Rᴴinv)
122-
copy!(Rc, R)
122+
Rc .= R
123123
i = 1
124124
conv = norm(Rᴴinv, Inf)
125125
while i < maxiter && conv > tol
@@ -128,7 +128,7 @@ function _left_polarnewton!(A::AbstractMatrix, W, P = similar(A, (0, 0)); tol =
128128
rmul!(R, γ)
129129
rmul!(Rᴴinv, 1 / γ)
130130
R, Rᴴinv = _avgdiff!(R, Rᴴinv)
131-
copy!(Rc, R)
131+
Rc .= R
132132
conv = norm(Rᴴinv, Inf)
133133
i += 1
134134
end

test/polar.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,16 @@ for T in (BLASFloats..., GenericFloats...), n in (37, m, 63)
1717
TestSuite.seed_rng!(123)
1818
if T BLASFloats
1919
if CUDA.functional()
20-
CUDA_POLAR_ALGS = (PolarViaSVD.((CUSOLVER_QRIteration(), CUSOLVER_SVDPolar(), CUSOLVER_Jacobi()))..., PolarNewton())
20+
# PolarNewton does not work yet on GPU
21+
CUDA_POLAR_ALGS = (PolarViaSVD.((CUSOLVER_QRIteration(), CUSOLVER_SVDPolar(), CUSOLVER_Jacobi()))...,) # PolarNewton())
2122
TestSuite.test_polar(CuMatrix{T}, (m, n), CUDA_POLAR_ALGS)
22-
n == m && TestSuite.test_polar(Diagonal{T, CuVector{T}}, m, (PolarNewton(),))
23+
#n == m && TestSuite.test_polar(Diagonal{T, CuVector{T}}, m, (PolarNewton(),))
2324
end
2425
if AMDGPU.functional()
25-
ROC_POLAR_ALGS = (PolarViaSVD.((ROCSOLVER_QRIteration(), ROCSOLVER_Jacobi()))..., PolarNewton())
26+
# PolarNewton does not work yet on GPU
27+
ROC_POLAR_ALGS = (PolarViaSVD.((ROCSOLVER_QRIteration(), ROCSOLVER_Jacobi()))...,) # PolarNewton())
2628
TestSuite.test_polar(ROCMatrix{T}, (m, n), ROC_POLAR_ALGS)
27-
n == m && TestSuite.test_polar(Diagonal{T, ROCVector{T}}, m, (PolarNewton(),))
29+
#n == m && TestSuite.test_polar(Diagonal{T, ROCVector{T}}, m, (PolarNewton(),))
2830
end
2931
end
3032
if !is_buildkite

0 commit comments

Comments
 (0)