Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
93f37f5
add GenericExt
sanderdemeyer Oct 29, 2025
5e08f7a
include weak dependencies
sanderdemeyer Oct 29, 2025
92534bc
Update ext/MatrixAlgebraKitGenericExt.jl
sanderdemeyer Oct 30, 2025
ddaf055
Update ext/MatrixAlgebraKitGenericExt.jl
sanderdemeyer Oct 30, 2025
88b9902
comments from Lukas
sanderdemeyer Oct 30, 2025
3eee758
remove `BigFloat_LQ_Householder`
sanderdemeyer Oct 30, 2025
f8d331e
Change struct names and relax type restrictions
sanderdemeyer Oct 30, 2025
03caa56
Split extensions
sanderdemeyer Oct 30, 2025
faa0921
fix copies
sanderdemeyer Oct 31, 2025
9ed02bd
fix type instability
sanderdemeyer Oct 31, 2025
c3e069e
Name change
sanderdemeyer Oct 31, 2025
7af4d52
Update ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl
sanderdemeyer Nov 4, 2025
c6ba4e4
Update ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl
sanderdemeyer Nov 4, 2025
fdb1222
Update ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl
sanderdemeyer Nov 4, 2025
49404bf
Merge branch 'QuantumKitHub:main' into GenericLinearAlgebraExt
sanderdemeyer Nov 4, 2025
837a7a4
Resolve comments
sanderdemeyer Nov 4, 2025
78f92f7
change folder names
sanderdemeyer Nov 4, 2025
fb1994f
add docs and namechange
sanderdemeyer Nov 4, 2025
3a36446
Remove unnecessary type parameters
lkdvos Nov 4, 2025
e6f64b2
simplify SVD and remove allocations
lkdvos Nov 4, 2025
4323b17
simplify Eigh and remove allocations
lkdvos Nov 4, 2025
f99112b
simplify QR
lkdvos Nov 4, 2025
89fd42b
simplify Eig and remove allocations
lkdvos Nov 4, 2025
35af44d
switch to `lmul!`
lkdvos Nov 5, 2025
4fe0fe1
docstring improvements
lkdvos Nov 5, 2025
c08dad4
move Diagonal tests
lkdvos Nov 5, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,24 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
GenericLinearAlgebra = "14197337-ba66-59df-a3e3-ca00e7dcff7a"
GenericSchur = "c145ed77-6b09-5dd9-b285-bf645a82121e"

[extensions]
MatrixAlgebraKitChainRulesCoreExt = "ChainRulesCore"
MatrixAlgebraKitAMDGPUExt = "AMDGPU"
MatrixAlgebraKitCUDAExt = "CUDA"
MatrixAlgebraKitGenericLinearAlgebraExt = "GenericLinearAlgebra"
MatrixAlgebraKitGenericSchurExt = "GenericSchur"

[compat]
AMDGPU = "2"
Aqua = "0.6, 0.7, 0.8"
ChainRulesCore = "1"
ChainRulesTestUtils = "1"
CUDA = "5"
GenericLinearAlgebra = "0.3.19"
GenericSchur = "0.5.6"
JET = "0.9, 0.10"
LinearAlgebra = "1"
SafeTestsets = "0.1"
Expand All @@ -42,4 +48,4 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras","ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA", "AMDGPU"]
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA", "AMDGPU", "GenericLinearAlgebra", "GenericSchur"]
161 changes: 161 additions & 0 deletions ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
module MatrixAlgebraKitGenericLinearAlgebraExt

using MatrixAlgebraKit
using MatrixAlgebraKit: sign_safe, check_input
using GenericLinearAlgebra: svd, svdvals!, eigen, eigvals, Hermitian, qr
using LinearAlgebra: I, Diagonal

function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}}
return GLA_svd_QRIteration()
end

function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix{T}, USVᴴ, alg::GLA_svd_QRIteration) where {T}
check_input(svd_compact!, A, USVᴴ, alg)
U, S, Vᴴ = USVᴴ
Ũ, S̃, Ṽ = svd(A)
copyto!(U, Ũ)
copyto!(S, Diagonal(S̃))
copyto!(Vᴴ, Ṽ') # conjugation to account for difference in convention
return U, S, Vᴴ
end

function MatrixAlgebraKit.svd_full!(A::AbstractMatrix{T}, USVᴴ, alg::GLA_svd_QRIteration) where {T}
check_input(svd_full!, A, USVᴴ, alg)
U, S, Vᴴ = USVᴴ
m, n = size(A)
minmn = min(m, n)
if minmn == 0
MatrixAlgebraKit.one!(U)
MatrixAlgebraKit.zero!(S)
MatrixAlgebraKit.one!(Vᴴ)
return USVᴴ
end
= fill!(S, zero(T))
U_compact, S_compact, V_compact = svd(A)
S̃[1:minmn, 1:minmn] .= Diagonal(S_compact)
copyto!(S, S̃)

U = _gram_schmidt!(U, U_compact)
Vᴴ = _gram_schmidt!(Vᴴ, V_compact; adjoint = true)

return MatrixAlgebraKit.gaugefix!(svd_full!, U, S, Vᴴ, m, n)
end

function MatrixAlgebraKit.svd_vals!(A::AbstractMatrix{T}, S, alg::GLA_svd_QRIteration) where {T}
check_input(svd_vals!, A, S, alg)
= svdvals!(A)
copyto!(S, S̃)
return S
end

function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}}
return GLA_eigh_Francis(; kwargs...)
end

function MatrixAlgebraKit.eigh_full!(A::AbstractMatrix{T}, DV, alg::GLA_eigh_Francis) where {T}
check_input(eigh_full!, A, DV, alg)
D, V = DV
eigval, eigvec = eigen(Hermitian(A); sortby = real)
copyto!(D, Diagonal(eigval))
copyto!(V, eigvec)
return D, V
end

function MatrixAlgebraKit.eigh_vals!(A::AbstractMatrix{T}, D, alg::GLA_eigh_Francis) where {T}
check_input(eigh_vals!, A, D, alg)
D = eigvals(A; sortby = real)
return real.(D)
end

function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}}
return GLA_QR_Householder(; kwargs...)
end

function MatrixAlgebraKit.qr_full!(A::AbstractMatrix, QR, alg::GLA_QR_Householder)
Q, R = QR
m, n = size(A)
minmn = min(m, n)
computeR = length(R) > 0

Q_zero = zeros(eltype(Q), (m, minmn))
R_zero = zeros(eltype(R), (minmn, n))
Q_compact, R_compact = _gla_householder_qr!(A, Q_zero, R_zero; alg.kwargs...)
Q = _gram_schmidt!(Q, Q_compact[:, 1:min(m, n)])
if computeR
R = fill!(R, zero(eltype(R)))
R[1:minmn, 1:n] .= R_compact
end
return Q, R
end

function MatrixAlgebraKit.qr_compact!(A::AbstractMatrix, QR, alg::GLA_QR_Householder)
check_input(qr_compact!, A, QR, alg)
Q, R = QR
Q, R = _gla_householder_qr!(A, Q, R; alg.kwargs...)
return Q, R
end

function _gla_householder_qr!(A::AbstractMatrix{T}, Q, R; positive = false, blocksize = 1, pivoted = false) where {T}
pivoted && throw(ArgumentError("Only pivoted = false implemented for GLA_QR_Householder."))
(blocksize == 1) || throw(ArgumentError("Only blocksize = 1 implemented for GLA_QR_Householder."))

m, n = size(A)
k = min(m, n)
computeR = length(R) > 0
Q̃, R̃ = qr(A)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can definitely be improved. GenericLinearAlgebra.jl has both a qrUnblocked! and a qrBlocked!. They act similar to Lapack, i.e. after calling F = qrUnblocked!(A), the elements on and above the diagonal of A (with F.factors === A) encode R, whereas the elements below the diagonal of A as well as a new vector F.τ encode the Householder reflectors that encode Q. From this, you can reconstruct both qr_compact!, qr_full! as well as qr_null!.

However, all of this is quite cryptic if you are not familiar with these ancient LAPACK storage strategies.

= convert(Array, Q̃)
if positive
@inbounds for j in 1:k
s = sign_safe(R̃[j, j])
@simd for i in 1:m
Q̃[i, j] *= s
end
end
end
copyto!(Q, Q̃)
if computeR
if positive
@inbounds for j in n:-1:1
@simd for i in 1:min(k, j)
R̃[i, j] = R̃[i, j] * conj(sign_safe(R̃[i, i]))
end
end
end
copyto!(R, R̃)
end
return Q, R
end

function _gram_schmidt(Q_compact)
m, minmn = size(Q_compact)
if minmn >= m
return Q_compact
end
Q = zeros(eltype(Q_compact), (m, m))
Q[:, 1:minmn] .= Q_compact
for j in (minmn + 1):m
v = rand(eltype(Q), m)
for i in 1:(j - 1)
r = sum([v[k] * conj(Q[k, i])] for k in 1:size(v)[1])[1]
v .= v .- r * Q[:, i]
end
Q[:, j] = v ./ MatrixAlgebraKit.norm(v)
end
return Q
end

function _gram_schmidt!(Q, Q_compact; adjoint = false)
= _gram_schmidt(Q_compact)
if adjoint
copyto!(Q, Q̃')
else
copyto!(Q, Q̃)
end
return Q
end

function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}}
return MatrixAlgebraKit.LQViaTransposedQR(GLA_QR_Householder(; kwargs...))
end

end
28 changes: 28 additions & 0 deletions ext/MatrixAlgebraKitGenericSchurExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
module MatrixAlgebraKitGenericSchurExt

using MatrixAlgebraKit
using MatrixAlgebraKit: check_input
using LinearAlgebra: Diagonal
using GenericSchur

function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}}
return GS_eig_Francis(; kwargs...)
end

function MatrixAlgebraKit.eig_full!(A::AbstractMatrix{T}, DV, alg::GS_eig_Francis) where {T}
check_input(eig_full!, A, DV, alg)
D, V = DV
D̃, Ṽ = GenericSchur.eigen!(A)
copyto!(D, Diagonal(D̃))
copyto!(V, Ṽ)
return D, V
end

function MatrixAlgebraKit.eig_vals!(A::AbstractMatrix{T}, D, alg::GS_eig_Francis) where {T}
check_input(eig_vals!, A, D, alg)
eigval = GenericSchur.eigvals!(A)
copyto!(D, eigval)
return D
end

end
1 change: 1 addition & 0 deletions src/MatrixAlgebraKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ export left_orth!, right_orth!, left_null!, right_null!
export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LAPACK_Simple, LAPACK_Expert,
LAPACK_QRIteration, LAPACK_Bisection, LAPACK_MultipleRelativelyRobustRepresentations,
LAPACK_DivideAndConquer, LAPACK_Jacobi
export GLA_QR_Householder, GS_eig_Francis, GLA_eigh_Francis, GLA_svd_QRIteration
export LQViaTransposedQR
export PolarViaSVD, PolarNewton
export DiagonalAlgorithm
Expand Down
10 changes: 10 additions & 0 deletions src/interface/decompositions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ elements of `L` are non-negative.
@algdef LAPACK_HouseholderLQ

# TODO:
@algdef GLA_QR_Householder
@algdef LAPACK_HouseholderQL
@algdef LAPACK_HouseholderRQ

Expand All @@ -56,6 +57,9 @@ eigenvalue decomposition of a matrix.

const LAPACK_EigAlgorithm = Union{LAPACK_Simple, LAPACK_Expert}

# TODO:
@algdef GS_eig_Francis

# Hermitian Eigenvalue Decomposition
# ----------------------------------
"""
Expand Down Expand Up @@ -100,6 +104,9 @@ const LAPACK_EighAlgorithm = Union{
LAPACK_MultipleRelativelyRobustRepresentations,
}

# TODO:
@algdef GLA_eigh_Francis

# Singular Value Decomposition
# ----------------------------
"""
Expand All @@ -117,6 +124,9 @@ const LAPACK_SVDAlgorithm = Union{
LAPACK_Jacobi,
}

# TODO:
@algdef GLA_svd_QRIteration

# =========================
# Polar decompositions
# =========================
Expand Down
116 changes: 116 additions & 0 deletions test/bigfloats/eig.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
using MatrixAlgebraKit
using Test
using TestExtras
using StableRNGs
using LinearAlgebra: Diagonal
using MatrixAlgebraKit: TruncatedAlgorithm, diagview, norm
using GenericSchur

const eltypes = (BigFloat, Complex{BigFloat})

@testset "eig_full! for T = $T" for T in eltypes
rng = StableRNG(123)
m = 24
alg = GS_eig_Francis()
A = randn(rng, T, m, m)
Tc = complex(T)

D, V = @constinferred eig_full(A; alg = ($alg))
@test eltype(D) == eltype(V) == Tc
@test A * V ≈ V * D

alg′ = @constinferred MatrixAlgebraKit.select_algorithm(eig_full!, A, $alg)

Ac = similar(A)
D2, V2 = @constinferred eig_full!(copy!(Ac, A), (D, V), alg′)
@test D2 === D
@test V2 === V
@test A * V ≈ V * D

Dc = @constinferred eig_vals(A, alg′)
@test eltype(Dc) == Tc
@test D ≈ Diagonal(Dc)
end

@testset "eig_trunc! for T = $T" for T in eltypes
rng = StableRNG(123)
m = 6
alg = GS_eig_Francis()
A = randn(rng, T, m, m)
A *= A' # TODO: deal with eigenvalue ordering etc
# eigenvalues are sorted by ascending real component...
D₀ = sort!(eig_vals(A); by = abs, rev = true)
rmin = findfirst(i -> abs(D₀[end - i]) != abs(D₀[end - i - 1]), 1:(m - 2))
r = length(D₀) - rmin
atol = sqrt(eps(real(T)))

D1, V1, ϵ1 = @constinferred eig_trunc(A; alg, trunc = truncrank(r))
D1base, V1base = @constinferred eig_full(A; alg)

@test length(diagview(D1)) == r
@test A * V1 ≈ V1 * D1
@test ϵ1 ≈ norm(view(D₀, (r + 1):m)) atol = atol

s = 1 + sqrt(eps(real(T)))
trunc = trunctol(; atol = s * abs(D₀[r + 1]))
D2, V2, ϵ2 = @constinferred eig_trunc(A; alg, trunc)
@test length(diagview(D2)) == r
@test A * V2 ≈ V2 * D2
@test ϵ2 ≈ norm(view(D₀, (r + 1):m)) atol = atol

s = 1 - sqrt(eps(real(T)))
trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1)
D3, V3, ϵ3 = @constinferred eig_trunc(A; alg, trunc)
@test length(diagview(D3)) == r
@test A * V3 ≈ V3 * D3
@test ϵ3 ≈ norm(view(D₀, (r + 1):m)) atol = atol

# trunctol keeps order, truncrank might not
# test for same subspace
@test V1 * ((V1' * V1) \ (V1' * V2)) ≈ V2
@test V2 * ((V2' * V2) \ (V2' * V1)) ≈ V1
@test V1 * ((V1' * V1) \ (V1' * V3)) ≈ V3
@test V3 * ((V3' * V3) \ (V3' * V1)) ≈ V1
end

@testset "eig_trunc! specify truncation algorithm T = $T" for T in eltypes
rng = StableRNG(123)
m = 4
atol = sqrt(eps(real(T)))
V = randn(rng, T, m, m)
D = Diagonal(real(T)[0.9, 0.3, 0.1, 0.01])
A = V * D * inv(V)
alg = TruncatedAlgorithm(GS_eig_Francis(), truncrank(2))
D2, V2, ϵ2 = @constinferred eig_trunc(A; alg)
@test diagview(D2) ≈ diagview(D)[1:2]
@test ϵ2 ≈ norm(diagview(D)[3:4]) atol = atol
@test_throws ArgumentError eig_trunc(A; alg, trunc = (; maxrank = 2))

alg = TruncatedAlgorithm(GS_eig_Francis(), truncerror(; atol = 0.2, p = 1))
D3, V3, ϵ3 = @constinferred eig_trunc(A; alg)
@test diagview(D3) ≈ diagview(D)[1:2]
@test ϵ3 ≈ norm(diagview(D)[3:4]) atol = atol
end

@testset "eig for Diagonal{$T}" for T in eltypes
rng = StableRNG(123)
m = 24
Ad = randn(rng, T, m)
A = Diagonal(Ad)
atol = sqrt(eps(real(T)))

D, V = @constinferred eig_full(A)
@test D isa Diagonal{T} && size(D) == size(A)
@test V isa Diagonal{T} && size(V) == size(A)
@test A * V ≈ V * D

D2 = @constinferred eig_vals(A)
@test D2 isa AbstractVector{T} && length(D2) == m
@test diagview(D) ≈ D2

A2 = Diagonal(T[0.9, 0.3, 0.1, 0.01])
alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2))
D2, V2, ϵ2 = @constinferred eig_trunc(A2; alg)
@test diagview(D2) ≈ diagview(A2)[1:2]
@test ϵ2 ≈ norm(diagview(A2)[3:4]) atol = atol
end
Loading