Skip to content

Commit f95d1b3

Browse files
committed
GPU-friendly QR/LQ
1 parent 4add253 commit f95d1b3

File tree

2 files changed

+10
-14
lines changed

2 files changed

+10
-14
lines changed

src/implementations/lq.jl

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ end
8383

8484
for f! in (:lq_full!, :lq_compact!)
8585
@eval function initialize_output(::typeof($f!), A::AbstractMatrix, ::DiagonalAlgorithm)
86-
return A, similar(A)
86+
return similar(A), A
8787
end
8888
end
8989

@@ -253,17 +253,15 @@ end
253253
# --------------
254254
function _diagonal_lq!(A::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix;
255255
positive::Bool=false)
256+
# note: Ad and Qd might share memory here so order of operations is important
256257
Ad = diagview(A)
257258
Ld = diagview(L)
258259
Qd = diagview(Q)
259260
if positive
260-
@inbounds @simd for i in eachindex(Ad)
261-
s = sign_safe(Ad[i])
262-
Qd[i] = s
263-
Ld[i] = conj(s) * Ad[i]
264-
end
261+
@. Ld = abs(Ad)
262+
@. Qd = sign_safe(Ad)
265263
else
266-
A === L || copy!(Ld, Ad)
264+
Ld .= Ad
267265
one!(Q)
268266
end
269267
return L, Q

src/implementations/qr.jl

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ end
8383

8484
for f! in (:qr_full!, :qr_compact!)
8585
@eval function initialize_output(::typeof($f!), A::AbstractMatrix, ::DiagonalAlgorithm)
86-
return similar(A), A
86+
return A, similar(A)
8787
end
8888
end
8989

@@ -216,17 +216,15 @@ end
216216
# --------------
217217
function _diagonal_qr!(A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix;
218218
positive::Bool=false)
219+
# note: Ad and Qd might share memory here so order of operations is important
219220
Ad = diagview(A)
220221
Qd = diagview(Q)
221222
Rd = diagview(R)
222223
if positive
223-
@inbounds @simd for i in eachindex(Ad)
224-
s = sign_safe(Ad[i])
225-
Qd[i] = s
226-
Rd[i] = conj(s) * Ad[i]
227-
end
224+
@. Rd = abs(Ad)
225+
@. Qd = sign_safe(Ad)
228226
else
229-
A === R || copy!(Rd, Ad)
227+
Rd .= Ad
230228
one!(Q)
231229
end
232230
return Q, R

0 commit comments

Comments
 (0)