Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/MatrixAlgebraKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ export left_polar!, right_polar!
export left_orth, right_orth, left_null, right_null
export left_orth!, right_orth!, left_null!, right_null!

export Native_HouseholderQR, Native_HouseholderLQ
export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LAPACK_Simple, LAPACK_Expert,
LAPACK_QRIteration, LAPACK_Bisection, LAPACK_MultipleRelativelyRobustRepresentations,
LAPACK_DivideAndConquer, LAPACK_Jacobi
Expand Down Expand Up @@ -69,6 +70,7 @@ export notrunc, truncrank, trunctol, truncerror, truncfilter
end

include("common/defaults.jl")
include("common/householder.jl")
include("common/initialization.jl")
include("common/pullbacks.jl")
include("common/safemethods.jl")
Expand Down
142 changes: 142 additions & 0 deletions src/common/householder.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
const IndexRange{T <: Integer} = Base.AbstractRange{T}

# Elementary Householder reflection
struct Householder{T, V <: AbstractVector, R <: IndexRange}
β::T
v::V
r::R
end
Base.adjoint(H::Householder) = Householder(conj(H.β), H.v, H.r)

function householder(x::AbstractVector, r::IndexRange = axes(x, 1), k = first(r))
i = findfirst(==(k), r)
i == nothing && error("k = $k should be in the range r = $r")
β, v, ν = _householder!(x[r], i)
return Householder(β, v, r), ν
end
# Householder reflector h that zeros the elements A[r,col] (except for A[k,col]) upon lmul!(h,A)
function householder(A::AbstractMatrix, r::IndexRange, col::Int, k = first(r))
i = findfirst(==(k), r)
i == nothing && error("k = $k should be in the range r = $r")
β, v, ν = _householder!(A[r, col], i)
return Householder(β, v, r), ν
end
# Householder reflector that zeros the elements A[row,r] (except for A[row,k]) upon rmul!(A,h')
function householder(A::AbstractMatrix, row::Int, r::IndexRange, k = first(r))
i = findfirst(==(k), r)
i == nothing && error("k = $k should be in the range r = $r")
β, v, ν = _householder!(conj!(A[row, r]), i)
return Householder(β, v, r), ν
end

# generate Householder vector based on vector v, such that applying the reflection
# to v yields a vector with single non-zero element on position i, whose value is
# positive and thus equal to norm(v)
function _householder!(v::AbstractVector{T}, i::Int = 1) where {T}
β::T = zero(T)
@inbounds begin
σ = abs2(zero(T))
@simd for k in 1:(i - 1)
σ += abs2(v[k])
end
@simd for k in (i + 1):length(v)
σ += abs2(v[k])
end
vi = v[i]
ν = sqrt(abs2(vi) + σ)

if σ == 0 && vi == ν
β = zero(vi)
else
if real(vi) < 0
vi = vi - ν
else
vi = ((vi - conj(vi)) * ν - σ) / (conj(vi) + ν)
end
@simd for k in 1:(i - 1)
v[k] /= vi
end
v[i] = 1
@simd for k in (i + 1):length(v)
v[k] /= vi
end
β = -conj(vi) / (ν)
end
end
return β, v, ν
end

function LinearAlgebra.lmul!(H::Householder, x::AbstractVector)
v = H.v
r = H.r
β = H.β
β == 0 && return x
@inbounds begin
μ = conj(zero(v[1])) * zero(x[r[1]])
i = 1
@simd for j in r
μ += conj(v[i]) * x[j]
i += 1
end
μ *= β
i = 1
@simd for j in H.r
x[j] -= μ * v[i]
i += 1
end
end
return x
end
function LinearAlgebra.lmul!(H::Householder, A::AbstractMatrix; cols = axes(A, 2))
v = H.v
r = H.r
β = H.β
β == 0 && return A
@inbounds begin
for k in cols
μ = conj(zero(v[1])) * zero(A[r[1], k])
i = 1
@simd for j in r
μ += conj(v[i]) * A[j, k]
i += 1
end
μ *= β
i = 1
@simd for j in H.r
A[j, k] -= μ * v[i]
i += 1
end
end
end
return A
end
function LinearAlgebra.rmul!(A::AbstractMatrix, H::Householder; rows = axes(A, 1))
v = H.v
r = H.r
β = H.β
β == 0 && return A
w = similar(A, length(rows))
fill!(w, 0)
all(in(axes(A, 2)), r) || error("Householder range r = $r not compatible with matrix A of size $(size(A))")
@inbounds begin
l = 1
for k in r
j = 1
@simd for i in rows
w[j] += A[i, k] * v[l]
j += 1
end
l += 1
end
l = 1
for k in r
j = 1
@simd for i in rows
A[i, k] -= β * w[j] * conj(v[l])
j += 1
end
l += 1
end
end
return A
end
82 changes: 82 additions & 0 deletions src/implementations/lq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -270,3 +270,85 @@ function _diagonal_lq!(
end

_diagonal_lq_null!(A::AbstractMatrix, N; positive::Bool = false) = N

# Native logic
# -------------
function lq_full!(A::AbstractMatrix, LQ, alg::Native_HouseholderLQ)
check_input(lq_full!, A, LQ, alg)
L, Q = LQ
A === Q &&
throw(ArgumentError("inplace Q not supported with native LQ implementation"))
_native_lq!(A, L, Q; alg.kwargs...)
return L, Q
end
function lq_compact!(A::AbstractMatrix, LQ, alg::Native_HouseholderLQ)
check_input(lq_compact!, A, LQ, alg)
L, Q = LQ
A === Q &&
throw(ArgumentError("inplace Q not supported with native LQ implementation"))
_native_lq!(A, L, Q; alg.kwargs...)
return L, Q
end
function lq_null!(A::AbstractMatrix, N, alg::Native_HouseholderLQ)
check_input(lq_null!, A, N, alg)
_native_lq_null!(A, N; alg.kwargs...)
return N
end

function _native_lq!(
A::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix;
positive::Bool = true # always true regardless of setting
)
m, n = size(A)
minmn = min(m, n)
@inbounds for i in 1:minmn
for j in 1:(i - 1)
L[i, j] = A[i, j]
end
β, v, L[i, i] = _householder!(conj!(view(A, i, i:n)), 1)
for j in (i + 1):size(L, 2)
L[i, j] = 0
end
H = Householder(conj(β), v, i:n)
rmul!(A, H; rows = (i + 1):m)
# A[i, i] == 1; store β instead
A[i, i] = β
end
# copy remaining rows for m > n
@inbounds for j in 1:size(L, 2)
for i in (minmn + 1):m
L[i, j] = A[i, j]
end
end
# build Q
one!(Q)
@inbounds for i in minmn:-1:1
β = A[i, i]
A[i, i] = 1
Hᴴ = Householder(β, view(A, i, i:n), i:n)
rmul!(Q, Hᴴ)
end
return L, Q
end

function _native_lq_null!(A::AbstractMatrix, Nᴴ::AbstractMatrix; positive::Bool = true)
m, n = size(A)
minmn = min(m, n)
@inbounds for i in 1:minmn
β, v, ν = _householder!(conj!(view(A, i, i:n)), 1)
H = Householder(conj(β), v, i:n)
rmul!(A, H; rows = (i + 1):m)
# A[i, i] == 1; store β instead
A[i, i] = β
end
# build Nᴴ
fill!(Nᴴ, zero(eltype(Nᴴ)))
one!(view(Nᴴ, 1:(n - minmn), (minmn + 1):n))
@inbounds for i in minmn:-1:1
β = A[i, i]
A[i, i] = 1
Hᴴ = Householder(β, view(A, i, i:n), i:n)
rmul!(Nᴴ, Hᴴ)
end
return Nᴴ
end
89 changes: 85 additions & 4 deletions src/implementations/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -233,10 +233,9 @@ end

_diagonal_qr_null!(A::AbstractMatrix, N; positive::Bool = false) = N

### GPU logic
# placed here to avoid code duplication since much of the logic is replicable across
# CUDA and AMDGPU
###
# GPU logic
# --------------
# placed here to avoid code duplication since much of the logic is replicable across CUDA and AMDGPU
function MatrixAlgebraKit.qr_full!(
A::AbstractMatrix, QR, alg::Union{CUSOLVER_HouseholderQR, ROCSOLVER_HouseholderQR}
)
Expand Down Expand Up @@ -321,3 +320,85 @@ function _gpu_qr_null!(
N = _gpu_unmqr!('L', 'N', A, τ, N)
return N
end

# Native logic
# --------------
function qr_full!(A::AbstractMatrix, QR, alg::Native_HouseholderQR)
check_input(qr_full!, A, QR, alg)
Q, R = QR
A === Q &&
throw(ArgumentError("inplace Q not supported with native QR implementation"))
_native_qr!(A, Q, R; alg.kwargs...)
return Q, R
end
function qr_compact!(A::AbstractMatrix, QR, alg::Native_HouseholderQR)
check_input(qr_compact!, A, QR, alg)
Q, R = QR
A === Q &&
throw(ArgumentError("inplace Q not supported with native QR implementation"))
_native_qr!(A, Q, R; alg.kwargs...)
return Q, R
end
function qr_null!(A::AbstractMatrix, N, alg::Native_HouseholderQR)
check_input(qr_null!, A, N, alg)
_native_qr_null!(A, N; alg.kwargs...)
return N
end

function _native_qr!(
A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix;
positive::Bool = true # always true regardless of setting
)
m, n = size(A)
minmn = min(m, n)
@inbounds for j in 1:minmn
for i in 1:(j - 1)
R[i, j] = A[i, j]
end
β, v, R[j, j] = _householder!(view(A, j:m, j), 1)
for i in (j + 1):size(R, 1)
R[i, j] = 0
end
H = Householder(β, v, j:m)
lmul!(H, A; cols = (j + 1):n)
# A[j,j] == 1; store β instead
A[j, j] = β
end
# copy remaining columns if m < n
@inbounds for j in (minmn + 1):n
for i in 1:size(R, 1)
R[i, j] = A[i, j]
end
end
# build Q
one!(Q)
@inbounds for j in minmn:-1:1
β = A[j, j]
A[j, j] = 1
Hᴴ = Householder(conj(β), view(A, j:m, j), j:m)
lmul!(Hᴴ, Q)
end
return Q, R
end

function _native_qr_null!(A::AbstractMatrix, N::AbstractMatrix; positive::Bool = true)
m, n = size(A)
minmn = min(m, n)
@inbounds for j in 1:minmn
β, v, ν = _householder!(view(A, j:m, j), 1)
H = Householder(β, v, j:m)
lmul!(H, A; cols = (j + 1):n)
# A[j,j] == 1; store β instead
A[j, j] = β
end
# build N
fill!(N, zero(eltype(N)))
one!(view(N, (minmn + 1):m, 1:(m - minmn)))
@inbounds for j in minmn:-1:1
β = A[j, j]
A[j, j] = 1
Hᴴ = Householder(conj(β), view(A, j:m, j), j:m)
lmul!(Hᴴ, N)
end
return N
end
18 changes: 18 additions & 0 deletions src/interface/decompositions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,24 @@

# QR, LQ, QL, RQ Decomposition
# ----------------------------
"""
Native_HouseholderQR()

Algorithm type to denote a native implementation for computing the QR decomposition of
a matrix using Householder reflectors. The diagonal elements of `R` will be non-negative
by construction.
"""
@algdef Native_HouseholderQR

"""
Native_HouseholderLQ()

Algorithm type to denote a native implementation for computing the LQ decomposition of
a matrix using Householder reflectors. The diagonal elements of `L` will be non-negative
by construction.
"""
@algdef Native_HouseholderLQ

"""
LAPACK_HouseholderQR(; blocksize, positive = false, pivoted = false)

Expand Down
3 changes: 3 additions & 0 deletions src/interface/lq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ default_lq_algorithm(A; kwargs...) = default_lq_algorithm(typeof(A); kwargs...)
function default_lq_algorithm(T::Type; kwargs...)
throw(MethodError(default_lq_algorithm, (T,)))
end
function default_lq_algorithm(::Type{T}; kwargs...) where {T <: AbstractMatrix}
return Native_HouseholderLQ(; kwargs...)
end
function default_lq_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.BlasMat}
return LAPACK_HouseholderLQ(; kwargs...)
end
Expand Down
3 changes: 3 additions & 0 deletions src/interface/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ default_qr_algorithm(A; kwargs...) = default_qr_algorithm(typeof(A); kwargs...)
function default_qr_algorithm(T::Type; kwargs...)
throw(MethodError(default_qr_algorithm, (T,)))
end
function default_qr_algorithm(::Type{T}; kwargs...) where {T <: AbstractMatrix}
return Native_HouseholderQR(; kwargs...)
end
function default_qr_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.BlasMat}
return LAPACK_HouseholderQR(; kwargs...)
end
Expand Down
Loading