Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
*.jl.*.cov
*.jl.cov
*.jl.mem
.*.swp
Manifest.toml
docs/build/
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MatrixAlgebraKit"
uuid = "6c742aac-3347-4629-af66-fc926824e5e4"
authors = ["Jutho <[email protected]> and contributors"]
version = "0.1.1"
version = "0.1.2"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -18,6 +18,7 @@ ChainRulesCore = "1"
ChainRulesTestUtils = "1"
JET = "0.9"
LinearAlgebra = "1"
SafeTestsets = "0.1"
StableRNGs = "1"
Test = "1"
TestExtras = "0.2,0.3"
Expand All @@ -28,10 +29,11 @@ julia = "1.10"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "JET", "Test", "TestExtras","ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote"]
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras","ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote"]
2 changes: 1 addition & 1 deletion ext/MatrixAlgebraKitChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -176,4 +176,4 @@ function ChainRulesCore.rrule(::typeof(right_polar!), A::AbstractMatrix, PWᴴ,
return PWᴴ, right_polar_pullback
end

end
end
2 changes: 1 addition & 1 deletion src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -178,4 +178,4 @@ macro check_size(x, sz, size=:size)
string($sz)
szx == $sz || throw(DimensionMismatch($err))
end)
end
end
2 changes: 1 addition & 1 deletion src/implementations/orthnull.jl
Original file line number Diff line number Diff line change
Expand Up @@ -215,4 +215,4 @@ function right_null!(A::AbstractMatrix, Nᴴ; kwargs...)
else
throw(ArgumentError("`right_null!` received unknown value `kind = $kind`"))
end
end
end
40 changes: 37 additions & 3 deletions src/implementations/truncation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,17 @@ function TruncationStrategy(; atol=nothing, rtol=nothing, maxrank=nothing)
if isnothing(maxrank) && isnothing(atol) && isnothing(rtol)
return NoTruncation()
elseif isnothing(maxrank)
@assert isnothing(rtol) "TODO: rtol"
return trunctol(atol)
atol = @something atol 0
rtol = @something rtol 0
return TruncationKeepAbove(atol, rtol)
else
return truncrank(maxrank)
if isnothing(atol) && isnothing(rtol)
return truncrank(maxrank)
else
atol = @something atol 0
rtol = @something rtol 0
return truncrank(maxrank) & TruncationKeepAbove(atol, rtol)
end
end
end

Expand Down Expand Up @@ -82,6 +89,28 @@ Truncation strategy to discard the values that are larger than `atol` in absolut
"""
truncabove(atol) = TruncationKeepFiltered(≤(atol) ∘ abs)

"""
TruncationComposition(trunc1::TruncationStrategy, trunc2::TruncationStrategy)

Compose two truncation strategies, keeping values common between the two strategies.
"""
struct TruncationComposition{T<:Tuple{Vararg{TruncationStrategy}}} <:
TruncationStrategy
components::T
end
function Base.:&(trunc1::TruncationStrategy, trunc2::TruncationStrategy)
return TruncationComposition((trunc1, trunc2))
end
function Base.:&(trunc1::TruncationComposition, trunc2::TruncationComposition)
return TruncationComposition((trunc1.components..., trunc2.components...))
end
function Base.:&(trunc1::TruncationComposition, trunc2::TruncationStrategy)
return TruncationComposition((trunc1.components..., trunc2))
end
function Base.:&(trunc1::TruncationStrategy, trunc2::TruncationComposition)
return TruncationComposition((trunc1, trunc2.components...))
end

# truncate!
# ---------
# Generic implementation: `findtruncated` followed by indexing
Expand Down Expand Up @@ -147,6 +176,11 @@ function findtruncated(values::AbstractVector, strategy::TruncationKeepAbove)
return 1:i
end

function findtruncated(values::AbstractVector, strategy::TruncationComposition)
inds = map(Base.Fix1(findtruncated, values), strategy.components)
return intersect(inds...)
end

"""
TruncatedAlgorithm(alg::AbstractAlgorithm, trunc::TruncationAlgorithm)

Expand Down
6 changes: 3 additions & 3 deletions src/interface/orthnull.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ of `kind`.
the input matrix `A`. Always use the return value of the function as it may not always be
possible to use the provided `CV` as output.

See also [`right_orth(!)`](@ref right_orth), [`left_orth(!)`](@ref left_orth), [`right_orth(!)`](@ref right_orth)
See also [`right_orth(!)`](@ref right_orth), [`left_null(!)`](@ref left_null), [`right_null(!)`](@ref right_null)
"""
function left_orth end
function left_orth! end
Expand Down Expand Up @@ -117,7 +117,7 @@ of `kind`.
the input matrix `A`. Always use the return value of the function as it may not always be
possible to use the provided `CVᴴ` as output.

See also [`left_orth(!)`](@ref left_orth), [`left_orth(!)`](@ref left_orth), [`right_orth(!)`](@ref right_orth)
See also [`left_orth(!)`](@ref left_orth), [`left_null(!)`](@ref left_null), [`right_null(!)`](@ref right_null)
"""
function right_orth end
function right_orth! end
Expand Down Expand Up @@ -224,4 +224,4 @@ function right_null!(A::AbstractMatrix; kwargs...)
end
function right_null(A::AbstractMatrix; kwargs...)
return right_null!(copy_input(right_null, A); kwargs...)
end
end
2 changes: 1 addition & 1 deletion src/pullbacks/polar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,4 @@ function right_polar_pullback!(ΔA::AbstractMatrix, PWᴴ, ΔPWᴴ)
ΔA .+= PΔWᴴ
end
return ΔA
end
end
8 changes: 6 additions & 2 deletions test/chainrules.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
using MatrixAlgebraKit
using Test
using TestExtras
using StableRNGs
using ChainRulesCore, ChainRulesTestUtils, Zygote
using MatrixAlgebraKit: diagview, TruncatedAlgorithm, PolarViaSVD
using LinearAlgebra: UpperTriangular, Diagonal, Hermitian
using LinearAlgebra: UpperTriangular, Diagonal, Hermitian, mul!

function remove_svdgauge_depence!(ΔU, ΔVᴴ, U, S, Vᴴ;
degeneracy_atol=MatrixAlgebraKit.default_pullback_gaugetol(S))
Expand Down Expand Up @@ -356,4 +360,4 @@ end
test_rrule(config, right_null, A; fkwargs=(; kind=:lqpos), output_tangent=ΔNᴴ,
atol=atol, rtol=rtol, rrule_f=rrule_via_ad, check_inferred=false)
end
end
end
7 changes: 7 additions & 0 deletions test/eig.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
using MatrixAlgebraKit
using Test
using TestExtras
using StableRNGs
using LinearAlgebra: Diagonal
using MatrixAlgebraKit: diagview

@testset "eig_full! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
rng = StableRNG(123)
m = 54
Expand Down
7 changes: 7 additions & 0 deletions test/eigh.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
using MatrixAlgebraKit
using Test
using TestExtras
using StableRNGs
using LinearAlgebra: LinearAlgebra, Diagonal, I
using MatrixAlgebraKit: diagview

@testset "eigh_full! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
rng = StableRNG(123)
m = 54
Expand Down
6 changes: 6 additions & 0 deletions test/lq.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
using MatrixAlgebraKit
using Test
using TestExtras
using StableRNGs
using LinearAlgebra: diag, I

@testset "lq_compact! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
rng = StableRNG(123)
m = 54
Expand Down
8 changes: 7 additions & 1 deletion test/orthnull.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
using MatrixAlgebraKit
using Test
using TestExtras
using StableRNGs
using LinearAlgebra: LinearAlgebra, I

@testset "left_orth and left_null for T = $T" for T in (Float32, Float64, ComplexF32,
ComplexF64)
rng = StableRNG(123)
Expand Down Expand Up @@ -209,4 +215,4 @@ end
end
end
end
end
end
5 changes: 5 additions & 0 deletions test/polar.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
using MatrixAlgebraKit
using Test
using TestExtras
using StableRNGs
using LinearAlgebra: LinearAlgebra, I, isposdef
using MatrixAlgebraKit: PolarViaSVD

@testset "left_polar! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
Expand Down
6 changes: 6 additions & 0 deletions test/qr.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
using MatrixAlgebraKit
using Test
using TestExtras
using StableRNGs
using LinearAlgebra: diag, I

@testset "qr_compact! and qr_null! for T = $T" for T in (Float32, Float64, ComplexF32,
ComplexF64)
rng = StableRNG(123)
Expand Down
36 changes: 16 additions & 20 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,44 +1,40 @@
using MatrixAlgebraKit
using Test
using TestExtras
using ChainRulesTestUtils
using StableRNGs
using Aqua
using JET
using LinearAlgebra: LinearAlgebra, diag, Diagonal, I, isposdef, diagind, mul!
using MatrixAlgebraKit: diagview
using SafeTestsets

@testset "QR / LQ Decomposition" begin
@safetestset "QR / LQ Decomposition" begin
include("qr.jl")
include("lq.jl")
end
@testset "Singular Value Decomposition" begin
@safetestset "Singular Value Decomposition" begin
include("svd.jl")
end
@testset "Hermitian Eigenvalue Decomposition" begin
@safetestset "Hermitian Eigenvalue Decomposition" begin
include("eigh.jl")
end
@testset "General Eigenvalue Decomposition" begin
@safetestset "General Eigenvalue Decomposition" begin
include("eig.jl")
end
@testset "Schur Decomposition" begin
@safetestset "Schur Decomposition" begin
include("schur.jl")
end
@testset "Polar Decomposition" begin
@safetestset "Polar Decomposition" begin
include("polar.jl")
end
@testset "Image and Null Space" begin
@safetestset "Image and Null Space" begin
include("orthnull.jl")
end
@testset "ChainRules" verbose = true begin
@safetestset "ChainRules" begin
include("chainrules.jl")
end

@testset "MatrixAlgebraKit.jl" begin
@testset "Code quality (Aqua.jl)" begin
@safetestset "MatrixAlgebraKit.jl" begin
@safetestset "Code quality (Aqua.jl)" begin
using MatrixAlgebraKit
using Aqua
Aqua.test_all(MatrixAlgebraKit)
end
@testset "Code linting (JET.jl)" begin
@safetestset "Code linting (JET.jl)" begin
using MatrixAlgebraKit
using JET
JET.test_package(MatrixAlgebraKit; target_defined_modules=true)
end
end
6 changes: 6 additions & 0 deletions test/schur.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
using MatrixAlgebraKit
using Test
using TestExtras
using StableRNGs
using LinearAlgebra: I

@testset "schur_full! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
rng = StableRNG(123)
m = 54
Expand Down
34 changes: 34 additions & 0 deletions test/svd.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
using MatrixAlgebraKit
using Test
using TestExtras
using StableRNGs
using LinearAlgebra: LinearAlgebra, Diagonal, I, isposdef
using MatrixAlgebraKit: diagview

@testset "svd_compact! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
rng = StableRNG(123)
m = 54
Expand Down Expand Up @@ -108,3 +115,30 @@ end
end
end
end

@testset "svd_trunc! mix maxrank and tol for T = $T" for T in
(Float32, Float64, ComplexF32,
ComplexF64)
rng = StableRNG(123)
if LinearAlgebra.LAPACK.version() < v"3.12.0"
algs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection())
else
algs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection(),
LAPACK_Jacobi())
end
m = 4
@testset "algorithm $alg" for alg in algs
U = qr_compact(randn(rng, T, m, m))[1]
S = Diagonal([0.9, 0.3, 0.1, 0.01])
Vᴴ = qr_compact(randn(rng, T, m, m))[1]
A = U * S * Vᴴ

U1, S1, V1ᴴ = svd_trunc(A; alg, trunc=(; rtol=0.2, maxrank=1))
@test length(S1.diag) == 1
@test S1.diag ≈ S.diag[1:1] rtol = sqrt(eps(real(T)))

U2, S2, V2ᴴ = svd_trunc(A; alg, trunc=(; rtol=0.2, maxrank=3))
@test length(S2.diag) == 2
@test S2.diag ≈ S.diag[1:2] rtol = sqrt(eps(real(T)))
end
end