Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
93f37f5
add GenericExt
sanderdemeyer Oct 29, 2025
5e08f7a
include weak dependencies
sanderdemeyer Oct 29, 2025
92534bc
Update ext/MatrixAlgebraKitGenericExt.jl
sanderdemeyer Oct 30, 2025
ddaf055
Update ext/MatrixAlgebraKitGenericExt.jl
sanderdemeyer Oct 30, 2025
88b9902
comments from Lukas
sanderdemeyer Oct 30, 2025
3eee758
remove `BigFloat_LQ_Householder`
sanderdemeyer Oct 30, 2025
f8d331e
Change struct names and relax type restrictions
sanderdemeyer Oct 30, 2025
03caa56
Split extensions
sanderdemeyer Oct 30, 2025
faa0921
fix copies
sanderdemeyer Oct 31, 2025
9ed02bd
fix type instability
sanderdemeyer Oct 31, 2025
c3e069e
Name change
sanderdemeyer Oct 31, 2025
7af4d52
Update ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl
sanderdemeyer Nov 4, 2025
c6ba4e4
Update ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl
sanderdemeyer Nov 4, 2025
fdb1222
Update ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl
sanderdemeyer Nov 4, 2025
49404bf
Merge branch 'QuantumKitHub:main' into GenericLinearAlgebraExt
sanderdemeyer Nov 4, 2025
837a7a4
Resolve comments
sanderdemeyer Nov 4, 2025
78f92f7
change folder names
sanderdemeyer Nov 4, 2025
fb1994f
add docs and namechange
sanderdemeyer Nov 4, 2025
3a36446
Remove unnecessary type parameters
lkdvos Nov 4, 2025
e6f64b2
simplify SVD and remove allocations
lkdvos Nov 4, 2025
4323b17
simplify Eigh and remove allocations
lkdvos Nov 4, 2025
f99112b
simplify QR
lkdvos Nov 4, 2025
89fd42b
simplify Eig and remove allocations
lkdvos Nov 4, 2025
35af44d
switch to `lmul!`
lkdvos Nov 5, 2025
4fe0fe1
docstring improvements
lkdvos Nov 5, 2025
c08dad4
move Diagonal tests
lkdvos Nov 5, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,23 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
GenericLinearAlgebra = "14197337-ba66-59df-a3e3-ca00e7dcff7a"
GenericSchur = "c145ed77-6b09-5dd9-b285-bf645a82121e"

[extensions]
MatrixAlgebraKitChainRulesCoreExt = "ChainRulesCore"
MatrixAlgebraKitAMDGPUExt = "AMDGPU"
MatrixAlgebraKitCUDAExt = "CUDA"
MatrixAlgebraKitGenericExt = ["GenericLinearAlgebra", "GenericSchur"]

[compat]
AMDGPU = "2"
Aqua = "0.6, 0.7, 0.8"
ChainRulesCore = "1"
ChainRulesTestUtils = "1"
CUDA = "5"
GenericLinearAlgebra = "0.3.19"
GenericSchur = "0.5.6"
JET = "0.9, 0.10"
LinearAlgebra = "1"
SafeTestsets = "0.1"
Expand All @@ -42,4 +47,4 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras","ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA", "AMDGPU"]
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA", "AMDGPU", "GenericLinearAlgebra", "GenericSchur"]
230 changes: 230 additions & 0 deletions ext/MatrixAlgebraKitGenericExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
module MatrixAlgebraKitGenericExt

using MatrixAlgebraKit
using MatrixAlgebraKit: LAPACK_SVDAlgorithm, LAPACK_EigAlgorithm, LAPACK_EighAlgorithm, LAPACK_QRIteration
using MatrixAlgebraKit: uppertriangular!
using MatrixAlgebraKit: @algdef, Algorithm, check_input
using MatrixAlgebraKit: sign_safe
using GenericLinearAlgebra
using GenericSchur
using LinearAlgebra: I, Diagonal

function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}}
return BigFloat_svd_QRIteration()
end

function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix{T}, USVᴴ, alg::BigFloat_svd_QRIteration) where {T <: Union{BigFloat, Complex{BigFloat}}}
check_input(svd_compact!, A, USVᴴ, alg)
U, S, V = GenericLinearAlgebra.svd(A)
return U, Diagonal(S), V' # conjugation to account for difference in convention
end

function MatrixAlgebraKit.svd_full!(A::AbstractMatrix{T}, USVᴴ, alg::BigFloat_svd_QRIteration)::Tuple{Matrix{T}, Matrix{BigFloat}, Matrix{T}} where {T <: Union{BigFloat, Complex{BigFloat}}}
check_input(svd_full!, A, USVᴴ, alg)
U, S, Vᴴ = USVᴴ
m, n = size(A)
minmn = min(m, n)
if minmn == 0
MatrixAlgebraKit.one!(U)
MatrixAlgebraKit.zero!(S)
MatrixAlgebraKit.one!(Vᴴ)
return USVᴴ
end
= zeros(eltype(S), size(S))
U_compact, S_compact, V_compact = GenericLinearAlgebra.svd(A)
S̃[1:minmn, 1:minmn] .= Diagonal(S_compact)
= _gram_schmidt(U_compact)
= _gram_schmidt(V_compact)

copyto!(U, Ũ)
copyto!(S, S̃)
copyto!(Vᴴ, Ṽ')

return MatrixAlgebraKit.gaugefix!(svd_full!, U, S, Vᴴ, m, n)
end

function MatrixAlgebraKit.svd_vals!(A::AbstractMatrix{T}, S, alg::BigFloat_svd_QRIteration) where {T <: Union{BigFloat, Complex{BigFloat}}}
check_input(svd_vals!, A, S, alg)
= GenericLinearAlgebra.svdvals!(A)
copyto!(S, S̃)
return S
end

function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}}
return BigFloat_eig_Francis(; kwargs...)
end

function MatrixAlgebraKit.eig_full!(A::AbstractMatrix{T}, DV, alg::BigFloat_eig_Francis)::Tuple{Diagonal{Complex{BigFloat}}, Matrix{Complex{BigFloat}}} where {T <: Union{BigFloat, Complex{BigFloat}}}
D, V = DV
D̃, Ṽ = GenericSchur.eigen!(A)
copyto!(D, Diagonal(D̃))
copyto!(V, Ṽ)
return D, V
end

function MatrixAlgebraKit.eig_vals!(A::AbstractMatrix{T}, D, alg::BigFloat_eig_Francis)::Vector{Complex{BigFloat}} where {T <: Union{BigFloat, Complex{BigFloat}}}
check_input(eig_vals!, A, D, alg)
eigval = GenericSchur.eigvals!(A)
return eigval
end


function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}}
return BigFloat_eigh_Francis(; kwargs...)
end

function MatrixAlgebraKit.eigh_full!(A::AbstractMatrix{T}, DV, alg::BigFloat_eigh_Francis)::Tuple{Diagonal{BigFloat}, Matrix{T}} where {T <: Union{BigFloat, Complex{BigFloat}}}
check_input(eigh_full!, A, DV, alg)
eigval, eigvec = GenericLinearAlgebra.eigen(A; sortby = λ -> real(λ))
return Diagonal(eigval), eigvec
end

function MatrixAlgebraKit.eigh_vals!(A::AbstractMatrix{T}, D, alg::BigFloat_eigh_Francis)::Vector{BigFloat} where {T <: Union{BigFloat, Complex{BigFloat}}}
check_input(eigh_vals!, A, D, alg)
D = GenericLinearAlgebra.eigvals(A; sortby = λ -> real(λ))
return real.(D)
end

function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}}
return BigFloat_QR_Householder(; kwargs...)
end

function MatrixAlgebraKit.qr_full!(A::AbstractMatrix, QR, alg::BigFloat_QR_Householder)
Q, R = QR
m, n = size(A)
minmn = min(m, n)
computeR = length(R) > 0

Q_zero = zeros(eltype(Q), (m, minmn))
R_zero = zeros(eltype(R), (minmn, n))
Q_compact, R_compact = _bigfloat_householder_qr!(A, Q_zero, R_zero; alg.kwargs...)
copyto!(Q, _gram_schmidt(Q_compact[:, 1:min(m, n)]))
if computeR
= zeros(eltype(R), m, n)
R̃[1:minmn, 1:n] .= R_compact
copyto!(R, R̃)
end
return Q, R
end

function MatrixAlgebraKit.lq_full!(A::AbstractMatrix, LQ, alg::BigFloat_LQ_Householder)
L, Q = LQ
m, n = size(A)
minmn = min(m, n)
computeL = length(L) > 0

L_zero = zeros(eltype(L), (m, minmn))
Q_zero = zeros(eltype(Q), (minmn, n))
L_compact, Q_compact = _bigfloat_householder_lq!(A, L_zero, Q_zero; alg.kwargs...)
copyto!(Q, _gram_schmidt(Q_compact'[:, 1:min(m, n)])')
if computeL
= zeros(eltype(L), m, n)
L̃[1:m, 1:minmn] .= L_compact
copyto!(L, L̃)
end
return L, Q
end

function MatrixAlgebraKit.qr_compact!(A::AbstractMatrix, QR, alg::BigFloat_QR_Householder)
check_input(qr_compact!, A, QR, alg)
Q, R = QR
Q, R = _bigfloat_householder_qr!(A, Q, R; alg.kwargs...)
return Q, R
end

function _bigfloat_householder_qr!(A::AbstractMatrix{T}, Q, R; positive = false, blocksize = 1, pivoted = false) where {T <: Union{BigFloat, Complex{BigFloat}}}
pivoted && throw(ArgumentError("Only pivoted = false implemented for BigFloats."))
(blocksize == 1) || throw(ArgumentError("Only blocksize = 1 implemented for BigFloats."))

m, n = size(A)
k = min(m, n)
computeR = length(R) > 0
Q̃, R̃ = GenericLinearAlgebra.qr(A)
= convert(Array, Q̃)
if positive
@inbounds for j in 1:k
s = sign_safe(R̃[j, j])
@simd for i in 1:m
Q̃[i, j] *= s
end
end
end
copyto!(Q, Q̃)
if computeR
if positive
@inbounds for j in n:-1:1
@simd for i in 1:min(k, j)
R̃[i, j] = R̃[i, j] * conj(sign_safe(R̃[i, i]))
end
end
end
copyto!(R, R̃)
end
return Q, R
end

function _gram_schmidt(Q_compact)
m, minmn = size(Q_compact)
if minmn >= m
return Q_compact
end
Q = zeros(eltype(Q_compact), (m, m))
Q[:, 1:minmn] .= Q_compact
for j in (minmn + 1):m
v = rand(eltype(Q), m)
for i in 1:(j - 1)
r = sum([v[k] * conj(Q[k, i])] for k in 1:size(v)[1])[1]
v .= v .- r * Q[:, i]
end
Q[:, j] = v ./ MatrixAlgebraKit.norm(v)
end
return Q
end

function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}}
return BigFloat_LQ_Householder(; kwargs...)
end

function MatrixAlgebraKit.lq_compact!(A::AbstractMatrix, LQ, alg::BigFloat_LQ_Householder)
check_input(lq_compact!, A, LQ, alg)
L, Q = LQ
L, Q = _bigfloat_householder_lq!(A, L, Q; alg.kwargs...)
return L, Q
end


function _bigfloat_householder_lq!(A::AbstractMatrix{T}, L, Q; positive = false, blocksize = 1, pivoted = false) where {T <: Union{BigFloat, Complex{BigFloat}}}
pivoted && throw(ArgumentError("Only pivoted = false implemented for BigFloats."))
(blocksize == 1) || throw(ArgumentError("Only blocksize = 1 implemented for BigFloats."))

m, n = size(A)
k = min(m, n)
computeL = length(L) > 0

Q̃, R̃ = GenericLinearAlgebra.qr(A')
= convert(Array, Q̃)

if positive
@inbounds for j in 1:k
s = sign_safe(R̃[j, j])
@simd for i in 1:n
Q̃[i, j] *= s
end
end
end
copyto!(Q, Q̃')
if computeL
if positive
@inbounds for j in m:-1:1
for i in 1:min(k, j)
R̃[i, j] = R̃[i, j] * conj(sign_safe(R̃[i, i]))
end
end
end
copyto!(L, R̃')

end
return L, Q
end

end
1 change: 1 addition & 0 deletions src/MatrixAlgebraKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ export left_orth!, right_orth!, left_null!, right_null!
export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LAPACK_Simple, LAPACK_Expert,
LAPACK_QRIteration, LAPACK_Bisection, LAPACK_MultipleRelativelyRobustRepresentations,
LAPACK_DivideAndConquer, LAPACK_Jacobi
export BigFloat_QR_Householder, BigFloat_LQ_Householder, BigFloat_eig_Francis, BigFloat_eigh_Francis, BigFloat_svd_QRIteration
export LQViaTransposedQR
export PolarViaSVD, PolarNewton
export DiagonalAlgorithm
Expand Down
11 changes: 11 additions & 0 deletions src/interface/decompositions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ elements of `L` are non-negative.
@algdef LAPACK_HouseholderLQ

# TODO:
@algdef BigFloat_QR_Householder
@algdef BigFloat_LQ_Householder
@algdef LAPACK_HouseholderQL
@algdef LAPACK_HouseholderRQ

Expand All @@ -56,6 +58,9 @@ eigenvalue decomposition of a matrix.

const LAPACK_EigAlgorithm = Union{LAPACK_Simple, LAPACK_Expert}

# TODO:
@algdef BigFloat_eig_Francis

# Hermitian Eigenvalue Decomposition
# ----------------------------------
"""
Expand Down Expand Up @@ -100,6 +105,9 @@ const LAPACK_EighAlgorithm = Union{
LAPACK_MultipleRelativelyRobustRepresentations,
}

# TODO:
@algdef BigFloat_eigh_Francis

# Singular Value Decomposition
# ----------------------------
"""
Expand All @@ -117,6 +125,9 @@ const LAPACK_SVDAlgorithm = Union{
LAPACK_Jacobi,
}

# TODO:
@algdef BigFloat_svd_QRIteration

# =========================
# Polar decompositions
# =========================
Expand Down
Loading