Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
93f37f5
add GenericExt
sanderdemeyer Oct 29, 2025
5e08f7a
include weak dependencies
sanderdemeyer Oct 29, 2025
92534bc
Update ext/MatrixAlgebraKitGenericExt.jl
sanderdemeyer Oct 30, 2025
ddaf055
Update ext/MatrixAlgebraKitGenericExt.jl
sanderdemeyer Oct 30, 2025
88b9902
comments from Lukas
sanderdemeyer Oct 30, 2025
3eee758
remove `BigFloat_LQ_Householder`
sanderdemeyer Oct 30, 2025
f8d331e
Change struct names and relax type restrictions
sanderdemeyer Oct 30, 2025
03caa56
Split extensions
sanderdemeyer Oct 30, 2025
faa0921
fix copies
sanderdemeyer Oct 31, 2025
9ed02bd
fix type instability
sanderdemeyer Oct 31, 2025
c3e069e
Name change
sanderdemeyer Oct 31, 2025
7af4d52
Update ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl
sanderdemeyer Nov 4, 2025
c6ba4e4
Update ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl
sanderdemeyer Nov 4, 2025
fdb1222
Update ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl
sanderdemeyer Nov 4, 2025
49404bf
Merge branch 'QuantumKitHub:main' into GenericLinearAlgebraExt
sanderdemeyer Nov 4, 2025
837a7a4
Resolve comments
sanderdemeyer Nov 4, 2025
78f92f7
change folder names
sanderdemeyer Nov 4, 2025
fb1994f
add docs and namechange
sanderdemeyer Nov 4, 2025
3a36446
Remove unnecessary type parameters
lkdvos Nov 4, 2025
e6f64b2
simplify SVD and remove allocations
lkdvos Nov 4, 2025
4323b17
simplify Eigh and remove allocations
lkdvos Nov 4, 2025
f99112b
simplify QR
lkdvos Nov 4, 2025
89fd42b
simplify Eig and remove allocations
lkdvos Nov 4, 2025
35af44d
switch to `lmul!`
lkdvos Nov 5, 2025
4fe0fe1
docstring improvements
lkdvos Nov 5, 2025
c08dad4
move Diagonal tests
lkdvos Nov 5, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,24 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
GenericLinearAlgebra = "14197337-ba66-59df-a3e3-ca00e7dcff7a"
GenericSchur = "c145ed77-6b09-5dd9-b285-bf645a82121e"

[extensions]
MatrixAlgebraKitChainRulesCoreExt = "ChainRulesCore"
MatrixAlgebraKitAMDGPUExt = "AMDGPU"
MatrixAlgebraKitCUDAExt = "CUDA"
MatrixAlgebraKitGenericLinearAlgebraExt = "GenericLinearAlgebra"
MatrixAlgebraKitGenericSchurExt = "GenericSchur"

[compat]
AMDGPU = "2"
Aqua = "0.6, 0.7, 0.8"
ChainRulesCore = "1"
ChainRulesTestUtils = "1"
CUDA = "5"
GenericLinearAlgebra = "0.3.19"
GenericSchur = "0.5.6"
JET = "0.9, 0.10"
LinearAlgebra = "1"
SafeTestsets = "0.1"
Expand All @@ -42,4 +48,4 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras","ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA", "AMDGPU"]
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA", "AMDGPU", "GenericLinearAlgebra", "GenericSchur"]
108 changes: 108 additions & 0 deletions ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
module MatrixAlgebraKitGenericLinearAlgebraExt

using MatrixAlgebraKit
using MatrixAlgebraKit: sign_safe, check_input, diagview
using GenericLinearAlgebra: svd!, svdvals!, eigen!, eigvals!, Hermitian, qr!
using LinearAlgebra: I, Diagonal, lmul!

function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}}
return GLA_QRIteration()
end

for f! in (:svd_compact!, :svd_full!, :svd_vals!)
@eval MatrixAlgebraKit.initialize_output(::typeof($f!), A::AbstractMatrix, ::GLA_QRIteration) = nothing
end

function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix, USVᴴ, ::GLA_QRIteration)
F = svd!(A)
U, S, Vᴴ = F.U, Diagonal(F.S), F.Vt
return MatrixAlgebraKit.gaugefix!(svd_compact!, U, S, Vᴴ, size(A)...)
end

function MatrixAlgebraKit.svd_full!(A::AbstractMatrix, USVᴴ, ::GLA_QRIteration)
F = svd!(A; full = true)
U, Vᴴ = F.U, F.Vt
S = MatrixAlgebraKit.zero!(similar(F.S, (size(U, 2), size(Vᴴ, 1))))
diagview(S) .= F.S
return MatrixAlgebraKit.gaugefix!(svd_full!, U, S, Vᴴ, size(A)...)
end

function MatrixAlgebraKit.svd_vals!(A::AbstractMatrix, S, ::GLA_QRIteration)
return svdvals!(A)
end

function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}}
return GLA_QRIteration(; kwargs...)
end

for f! in (:eigh_full!, :eigh_vals!)
@eval MatrixAlgebraKit.initialize_output(::typeof($f!), A::AbstractMatrix, ::GLA_QRIteration) = nothing
end

function MatrixAlgebraKit.eigh_full!(A::AbstractMatrix, DV, ::GLA_QRIteration)
eigval, eigvec = eigen!(Hermitian(A); sortby = real)
return Diagonal(eigval::AbstractVector{real(eltype(A))}), eigvec::AbstractMatrix{eltype(A)}
end

function MatrixAlgebraKit.eigh_vals!(A::AbstractMatrix, D, ::GLA_QRIteration)
return eigvals!(Hermitian(A); sortby = real)
end

function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}}
return GLA_HouseholderQR(; kwargs...)
end

function MatrixAlgebraKit.qr_full!(A::AbstractMatrix, QR, alg::GLA_HouseholderQR)
check_input(qr_full!, A, QR, alg)
Q, R = QR
return _gla_householder_qr!(A, Q, R; alg.kwargs...)
end

function MatrixAlgebraKit.qr_compact!(A::AbstractMatrix, QR, alg::GLA_HouseholderQR)
check_input(qr_compact!, A, QR, alg)
Q, R = QR
return _gla_householder_qr!(A, Q, R; alg.kwargs...)
end

function _gla_householder_qr!(A::AbstractMatrix, Q, R; positive = false, blocksize = 1, pivoted = false)
pivoted && throw(ArgumentError("Only pivoted = false implemented for GLA_HouseholderQR."))
(blocksize == 1) || throw(ArgumentError("Only blocksize = 1 implemented for GLA_HouseholderQR."))

m, n = size(A)
k = min(m, n)
Q̃, R̃ = qr!(A)
lmul!(Q̃, MatrixAlgebraKit.one!(Q))

if positive
@inbounds for j in 1:k
s = sign_safe(R̃[j, j])
@simd for i in 1:m
Q[i, j] *= s
end
end
end

computeR = length(R) > 0
if computeR
if positive
@inbounds for j in n:-1:1
@simd for i in 1:min(k, j)
R[i, j] = R̃[i, j] * conj(sign_safe(R̃[i, i]))
end
@simd for i in (min(k, j) + 1):size(R, 1)
R[i, j] = zero(eltype(R))
end
end
else
R[1:k, :] .= R̃
MatrixAlgebraKit.zero!(@view(R[(k + 1):end, :]))
end
end
return Q, R
end

function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}}
return MatrixAlgebraKit.LQViaTransposedQR(GLA_HouseholderQR(; kwargs...))
end

end
25 changes: 25 additions & 0 deletions ext/MatrixAlgebraKitGenericSchurExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
module MatrixAlgebraKitGenericSchurExt

using MatrixAlgebraKit
using MatrixAlgebraKit: check_input
using LinearAlgebra: Diagonal
using GenericSchur

function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}}
return GS_QRIteration(; kwargs...)
end

for f! in (:eig_full!, :eig_vals!)
@eval MatrixAlgebraKit.initialize_output(::typeof($f!), A::AbstractMatrix, ::GS_QRIteration) = nothing
end

function MatrixAlgebraKit.eig_full!(A::AbstractMatrix, DV, ::GS_QRIteration)
D, V = GenericSchur.eigen!(A)
return Diagonal(D), V
end

function MatrixAlgebraKit.eig_vals!(A::AbstractMatrix, D, ::GS_QRIteration)
return GenericSchur.eigvals!(A)
end

end
1 change: 1 addition & 0 deletions src/MatrixAlgebraKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ export left_orth!, right_orth!, left_null!, right_null!
export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LAPACK_Simple, LAPACK_Expert,
LAPACK_QRIteration, LAPACK_Bisection, LAPACK_MultipleRelativelyRobustRepresentations,
LAPACK_DivideAndConquer, LAPACK_Jacobi
export GLA_HouseholderQR, GLA_QRIteration, GS_QRIteration
export LQViaTransposedQR
export PolarViaSVD, PolarNewton
export DiagonalAlgorithm
Expand Down
31 changes: 29 additions & 2 deletions src/interface/decompositions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ Algorithm type to denote the standard LAPACK algorithm for computing the QR deco
a matrix using Householder reflectors. The specific LAPACK function can be controlled using
the keyword arugments, i.e. `?geqrt` will be chosen if `blocksize > 1`. With
`blocksize == 1`, `?geqrf` will be chosen if `pivoted == false` and `?geqp3` will be chosen
if `pivoted == true`. The keyword `positive=true` can be used to ensure that the diagonal
if `pivoted == true`. The keyword `positive = true` can be used to ensure that the diagonal
elements of `R` are non-negative.
"""
@algdef LAPACK_HouseholderQR
Expand All @@ -27,11 +27,21 @@ elements of `R` are non-negative.
Algorithm type to denote the standard LAPACK algorithm for computing the LQ decomposition of
a matrix using Householder reflectors. The specific LAPACK function can be controlled using
the keyword arugments, i.e. `?gelqt` will be chosen if `blocksize > 1` or `?gelqf` will be
chosen if `blocksize == 1`. The keyword `positive=true` can be used to ensure that the diagonal
chosen if `blocksize == 1`. The keyword `positive = true` can be used to ensure that the diagonal
elements of `L` are non-negative.
"""
@algdef LAPACK_HouseholderLQ

"""
GLA_HouseholderQR(; positive = false)

Algorithm type to denote the GenericLinearAlgebra.jl implementation for computing the QR decomposition
of a matrix using Householder reflectors. Currently, only `blocksize = 1` and `pivoted == false`
are supported. The keyword `positive = true` can be used to ensure that the diagonal elements
of `R` are non-negative.
"""
@algdef GLA_HouseholderQR

# TODO:
@algdef LAPACK_HouseholderQL
@algdef LAPACK_HouseholderRQ
Expand All @@ -56,6 +66,14 @@ eigenvalue decomposition of a matrix.

const LAPACK_EigAlgorithm = Union{LAPACK_Simple, LAPACK_Expert}

"""
GS_QRIteration()

Algorithm type to denote the GenericSchur.jl implementation for computing the
eigenvalue decomposition of a non-Hermitian matrix.
"""
@algdef GS_QRIteration

# Hermitian Eigenvalue Decomposition
# ----------------------------------
"""
Expand Down Expand Up @@ -100,6 +118,15 @@ const LAPACK_EighAlgorithm = Union{
LAPACK_MultipleRelativelyRobustRepresentations,
}

"""
GLA_QRIteration()

Algorithm type to denote the GenericLinearAlgebra.jl implementation for computing the
eigenvalue decomposition of a Hermitian matrix, or the singular value decomposition of
a general matrix.
"""
@algdef GLA_QRIteration

# Singular Value Decomposition
# ----------------------------
"""
Expand Down
5 changes: 3 additions & 2 deletions test/eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ using StableRNGs
using LinearAlgebra: Diagonal
using MatrixAlgebraKit: TruncatedAlgorithm, diagview, norm

const BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
GenericFloats = (Float16, BigFloat, Complex{BigFloat})

@testset "eig_full! for T = $T" for T in BLASFloats
rng = StableRNG(123)
Expand Down Expand Up @@ -91,7 +92,7 @@ end
@test ϵ3 ≈ norm(diagview(D)[3:4]) atol = atol
end

@testset "eig for Diagonal{$T}" for T in BLASFloats
@testset "eig for Diagonal{$T}" for T in (BLASFloats..., GenericFloats...)
rng = StableRNG(123)
m = 54
Ad = randn(rng, T, m)
Expand Down
5 changes: 3 additions & 2 deletions test/eigh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ using StableRNGs
using LinearAlgebra: LinearAlgebra, Diagonal, I
using MatrixAlgebraKit: TruncatedAlgorithm, diagview, norm

const BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
GenericFloats = (Float16, BigFloat, Complex{BigFloat})

@testset "eigh_full! for T = $T" for T in BLASFloats
rng = StableRNG(123)
Expand Down Expand Up @@ -100,7 +101,7 @@ end
@test ϵ3 ≈ norm(diagview(D)[3:4]) atol = atol
end

@testset "eigh for Diagonal{$T}" for T in BLASFloats
@testset "eigh for Diagonal{$T}" for T in (BLASFloats..., GenericFloats...)
rng = StableRNG(123)
m = 54
Ad = randn(rng, T, m)
Expand Down
93 changes: 93 additions & 0 deletions test/genericlinearalgebra/eigh.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
using MatrixAlgebraKit
using Test
using TestExtras
using StableRNGs
using LinearAlgebra: LinearAlgebra, Diagonal, I
using MatrixAlgebraKit: TruncatedAlgorithm, diagview, norm
using GenericLinearAlgebra

const eltypes = (BigFloat, Complex{BigFloat})

@testset "eigh_full! for T = $T" for T in eltypes
rng = StableRNG(123)
m = 54
alg = GLA_QRIteration()

A = randn(rng, T, m, m)
A = (A + A') / 2

D, V = @constinferred eigh_full(A; alg)
@test A * V ≈ V * D
@test isunitary(V)
@test all(isreal, D)

D2, V2 = eigh_full!(copy(A), (D, V), alg)
@test D2 ≈ D
@test V2 ≈ V

D3 = @constinferred eigh_vals(A, alg)
@test D ≈ Diagonal(D3)
end

@testset "eigh_trunc! for T = $T" for T in eltypes
rng = StableRNG(123)
m = 54
alg = GLA_QRIteration()
A = randn(rng, T, m, m)
A = A * A'
A = (A + A') / 2
Ac = similar(A)
D₀ = reverse(eigh_vals(A))

r = m - 2
s = 1 + sqrt(eps(real(T)))
atol = sqrt(eps(real(T)))

D1, V1, ϵ1 = @constinferred eigh_trunc(A; alg, trunc = truncrank(r))
Dfull, Vfull = eigh_full(A; alg)
@test length(diagview(D1)) == r
@test isisometric(V1)
@test A * V1 ≈ V1 * D1
@test LinearAlgebra.opnorm(A - V1 * D1 * V1') ≈ D₀[r + 1]
@test ϵ1 ≈ norm(view(D₀, (r + 1):m)) atol = atol

trunc = trunctol(; atol = s * D₀[r + 1])
D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg, trunc)
@test length(diagview(D2)) == r
@test isisometric(V2)
@test A * V2 ≈ V2 * D2
@test ϵ2 ≈ norm(view(D₀, (r + 1):m)) atol = atol

s = 1 - sqrt(eps(real(T)))
trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1)
D3, V3, ϵ3 = @constinferred eigh_trunc(A; alg, trunc)
@test length(diagview(D3)) == r
@test A * V3 ≈ V3 * D3
@test ϵ3 ≈ norm(view(D₀, (r + 1):m)) atol = atol

# test for same subspace
@test V1 * (V1' * V2) ≈ V2
@test V2 * (V2' * V1) ≈ V1
@test V1 * (V1' * V3) ≈ V3
@test V3 * (V3' * V1) ≈ V1
end

@testset "eigh_trunc! specify truncation algorithm T = $T" for T in eltypes
rng = StableRNG(123)
m = 4
atol = sqrt(eps(real(T)))
V = qr_compact(randn(rng, T, m, m))[1]
D = Diagonal(real(T)[0.9, 0.3, 0.1, 0.01])
A = V * D * V'
A = (A + A') / 2
alg = TruncatedAlgorithm(GLA_QRIteration(), truncrank(2))
D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg)
@test diagview(D2) ≈ diagview(D)[1:2]
@test_throws ArgumentError eigh_trunc(A; alg, trunc = (; maxrank = 2))
@test ϵ2 ≈ norm(diagview(D)[3:4]) atol = atol

alg = TruncatedAlgorithm(GLA_QRIteration(), truncerror(; atol = 0.2))
D3, V3, ϵ3 = @constinferred eigh_trunc(A; alg)
@test diagview(D3) ≈ diagview(D)[1:2]
@test ϵ3 ≈ norm(diagview(D)[3:4]) atol = atol
end
Loading