Skip to content

Commit 55a70b9

Browse files
Juthokshyatt
authored andcommitted
native_qr
1 parent 91a8a69 commit 55a70b9

File tree

4 files changed

+235
-4
lines changed

4 files changed

+235
-4
lines changed

src/MatrixAlgebraKit.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ export notrunc, truncrank, trunctol, truncerror, truncfilter
7272
end
7373

7474
include("common/defaults.jl")
75+
include("common/householder.jl")
7576
include("common/initialization.jl")
7677
include("common/pullbacks.jl")
7778
include("common/safemethods.jl")

src/common/householder.jl

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
const IndexRange{T <: Integer} = Base.AbstractRange{T}
2+
3+
# Elementary Householder reflection
4+
struct Householder{T, V <: AbstractVector, R <: IndexRange}
5+
β::T
6+
v::V
7+
r::R
8+
end
9+
Base.adjoint(H::Householder) = Householder(conj(H.β), H.v, H.r)
10+
11+
function householder(x::AbstractVector, r::IndexRange = axes(x, 1), k = first(r))
12+
i = findfirst(equalto(k), r)
13+
i == nothing && error("k = $k should be in the range r = $r")
14+
β, v, ν = _householder!(x[r], i)
15+
return Householder(β, v, r), ν
16+
end
17+
# Householder reflector h that zeros the elements A[r,col] (except for A[k,col]) upon lmul!(h,A)
18+
function householder(A::AbstractMatrix, r::IndexRange, col::Int, k = first(r))
19+
i = findfirst(equalto(k), r)
20+
i == nothing && error("k = $k should be in the range r = $r")
21+
β, v, ν = _householder!(A[r, col], i)
22+
return Householder(β, v, r), ν
23+
end
24+
# Householder reflector that zeros the elements A[row,r] (except for A[row,k]) upon rmul!(A,h')
25+
function householder(A::AbstractMatrix, row::Int, r::IndexRange, k = first(r))
26+
i = findfirst(equalto(k), r)
27+
i == nothing && error("k = $k should be in the range r = $r")
28+
β, v, ν = _householder!(conj!(A[row, r]), i)
29+
return Householder(β, v, r), ν
30+
end
31+
32+
# generate Householder vector based on vector v, such that applying the reflection
33+
# to v yields a vector with single non-zero element on position i, whose value is
34+
# positive and thus equal to norm(v)
35+
function _householder!(v::AbstractVector{T}, i::Int = 1) where {T}
36+
β::T = zero(T)
37+
@inbounds begin
38+
σ = abs2(zero(T))
39+
@simd for k in 1:(i - 1)
40+
σ += abs2(v[k])
41+
end
42+
@simd for k in (i + 1):length(v)
43+
σ += abs2(v[k])
44+
end
45+
vi = v[i]
46+
ν = sqrt(abs2(vi) + σ)
47+
48+
if σ == 0 && vi == ν
49+
β = zero(vi)
50+
else
51+
if real(vi) < 0
52+
vi = vi - ν
53+
else
54+
vi = ((vi - conj(vi)) * ν - σ) / (conj(vi) + ν)
55+
end
56+
@simd for k in 1:(i - 1)
57+
v[k] /= vi
58+
end
59+
v[i] = 1
60+
@simd for k in (i + 1):length(v)
61+
v[k] /= vi
62+
end
63+
β = -conj(vi) / (ν)
64+
end
65+
end
66+
return β, v, ν
67+
end
68+
69+
function LinearAlgebra.lmul!(H::Householder, x::AbstractVector)
70+
v = H.v
71+
r = H.r
72+
β = H.β
73+
β == 0 && return x
74+
@inbounds begin
75+
μ = conj(zero(v[1])) * zero(x[r[1]])
76+
i = 1
77+
@simd for j in r
78+
μ += conj(v[i]) * x[j]
79+
i += 1
80+
end
81+
μ *= β
82+
i = 1
83+
@simd for j in H.r
84+
x[j] -= μ * v[i]
85+
i += 1
86+
end
87+
end
88+
return x
89+
end
90+
function LinearAlgebra.lmul!(H::Householder, A::AbstractMatrix; cols = axes(A, 2))
91+
v = H.v
92+
r = H.r
93+
β = H.β
94+
β == 0 && return A
95+
@inbounds begin
96+
for k in cols
97+
μ = conj(zero(v[1])) * zero(A[r[1], k])
98+
i = 1
99+
@simd for j in r
100+
μ += conj(v[i]) * A[j, k]
101+
i += 1
102+
end
103+
μ *= β
104+
i = 1
105+
@simd for j in H.r
106+
A[j, k] -= μ * v[i]
107+
i += 1
108+
end
109+
end
110+
end
111+
return A
112+
end
113+
function LinearAlgebra.rmul!(A::AbstractMatrix, H::Householder; rows = axes(A, 1))
114+
v = H.v
115+
r = H.r
116+
β = H.β
117+
β == 0 && return A
118+
w = similar(A, length(rows))
119+
fill!(w, 0)
120+
all(in(axes(A, 2)), r) || error("Householder range r = $r not compatible with matrix A of size $(size(A))")
121+
@inbounds begin
122+
l = 1
123+
for k in r
124+
j = 1
125+
@simd for i in rows
126+
w[j] += A[i, k] * v[l]
127+
j += 1
128+
end
129+
l += 1
130+
end
131+
l = 1
132+
for k in r
133+
j = 1
134+
@simd for i in rows
135+
A[i, k] -= β * w[j] * conj(v[l])
136+
j += 1
137+
end
138+
l += 1
139+
end
140+
end
141+
return A
142+
end

src/implementations/qr.jl

Lines changed: 83 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -233,10 +233,9 @@ end
233233

234234
_diagonal_qr_null!(A::AbstractMatrix, N; positive::Bool = false) = N
235235

236-
### GPU logic
237-
# placed here to avoid code duplication since much of the logic is replicable across
238-
# CUDA and AMDGPU
239-
###
236+
# GPU logic
237+
# --------------
238+
# placed here to avoid code duplication since much of the logic is replicable across CUDA and AMDGPU
240239
function MatrixAlgebraKit.qr_full!(
241240
A::AbstractMatrix, QR, alg::Union{CUSOLVER_HouseholderQR, ROCSOLVER_HouseholderQR}
242241
)
@@ -325,3 +324,83 @@ function _gpu_qr_null!(
325324
N = _gpu_unmqr!('L', 'N', A, τ, N)
326325
return N
327326
end
327+
328+
# Native logic
329+
# --------------
330+
function qr_full!(A::AbstractMatrix, QR, alg::Native_HouseholderQR)
331+
check_input(qr_full!, A, QR, alg)
332+
Q, R = QR
333+
A === Q &&
334+
throw(ArgumentError("inplace Q not supported with native QR implementation"))
335+
_native_qr!(A, Q, R; alg.kwargs...)
336+
return Q, R
337+
end
338+
function qr_compact!(A::AbstractMatrix, QR, alg::Native_HouseholderQR)
339+
check_input(qr_compact!, A, QR, alg)
340+
Q, R = QR
341+
A === Q &&
342+
throw(ArgumentError("inplace Q not supported with native QR implementation"))
343+
_native_qr!(A, Q, R; alg.kwargs...)
344+
return Q, R
345+
end
346+
function qr_null!(A::AbstractMatrix, N, alg::Native_HouseholderQR)
347+
check_input(qr_null!, A, N, alg)
348+
_native_qr_null!(A, N; alg.kwargs...)
349+
return N
350+
end
351+
352+
function _native_qr!(
353+
A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix
354+
)
355+
m, n = size(A)
356+
minmn = min(m, n)
357+
@inbounds for j in 1:minmn
358+
for i in 1:(j - 1)
359+
R[i, j] = A[i, j]
360+
end
361+
β, v, R[j, j] = _householder!(view(A, j:m, j), 1)
362+
for i in (j + 1):size(R, 1)
363+
R[i, j] = 0
364+
end
365+
H = Householder(β, v, j:m)
366+
lmul!(H, A; cols = (j + 1):n)
367+
# A[j,j] == 1; store β instead
368+
A[j, j] = β
369+
end
370+
@inbounds for j in (minmn + 1):n
371+
for i in 1:size(R, 1)
372+
R[i, j] = A[i, j]
373+
end
374+
end
375+
# build Q
376+
one!(Q)
377+
for j in minmn:-1:1
378+
β = A[j, j]
379+
A[j, j] = 1
380+
Hᴴ = Householder(conj(β), view(A, j:m, j), j:m)
381+
lmul!(Hᴴ, Q)
382+
end
383+
return Q, R
384+
end
385+
386+
function _native_qr_null!(A::AbstractMatrix, N::AbstractMatrix)
387+
m, n = size(A)
388+
minmn = min(m, n)
389+
@inbounds for j in 1:minmn
390+
β, v, ν = _householder!(view(A, j:m, j), 1)
391+
H = Householder(β, v, j:m)
392+
lmul!(H, A; cols = (j + 1):n)
393+
# A[j,j] == 1; store β instead
394+
A[j, j] = β
395+
end
396+
# build Q
397+
fill!(N, zero(eltype(N)))
398+
one!(view(N, (minmn + 1):m, 1:(m - minmn)))
399+
for j in minmn:-1:1
400+
β = A[j, j]
401+
A[j, j] = 1
402+
Hᴴ = Householder(conj(β), view(A, j:m, j), j:m)
403+
lmul!(Hᴴ, N)
404+
end
405+
return N
406+
end

src/interface/decompositions.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,15 @@
99

1010
# QR, LQ, QL, RQ Decomposition
1111
# ----------------------------
12+
"""
13+
Native_HouseholderQR(; blocksize, positive = false, pivoted = false)
14+
15+
Algorithm type to denote a native implementation for computing the QR decomposition of
16+
a matrix using Householder reflectors, .The diagonal elements of `R` will be non-negative
17+
by construction.
18+
"""
19+
@algdef Native_HouseholderQR
20+
1221
"""
1322
LAPACK_HouseholderQR(; blocksize, positive = false, pivoted = false)
1423

0 commit comments

Comments
 (0)