Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MatrixAlgebraKit"
uuid = "6c742aac-3347-4629-af66-fc926824e5e4"
authors = ["Jutho <[email protected]> and contributors"]
version = "0.1.1"
version = "0.1.2"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
40 changes: 37 additions & 3 deletions src/implementations/truncation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,17 @@
if isnothing(maxrank) && isnothing(atol) && isnothing(rtol)
return NoTruncation()
elseif isnothing(maxrank)
@assert isnothing(rtol) "TODO: rtol"
return trunctol(atol)
atol = @something atol 0
rtol = @something rtol 0
return TruncationKeepAbove(atol, rtol)
else
return truncrank(maxrank)
if isnothing(atol) && isnothing(rtol)
return truncrank(maxrank)
else
atol = @something atol 0
rtol = @something rtol 0
return truncrank(maxrank) & TruncationKeepAbove(atol, rtol)
end
end
end

Expand Down Expand Up @@ -82,6 +89,28 @@
"""
truncabove(atol) = TruncationKeepFiltered(≤(atol) ∘ abs)

"""
TruncationIntersection(trunc1::TruncationStrategy, trunc2::TruncationStrategy)

Compose two truncation strategies, keeping values common between the two strategies.
"""
struct TruncationIntersection{T<:Tuple{Vararg{TruncationStrategy}}} <:
TruncationStrategy
components::T
end
function Base.:&(trunc1::TruncationStrategy, trunc2::TruncationStrategy)
return TruncationIntersection((trunc1, trunc2))
end
function Base.:&(trunc1::TruncationIntersection, trunc2::TruncationIntersection)
return TruncationIntersection((trunc1.components..., trunc2.components...))

Check warning on line 105 in src/implementations/truncation.jl

View check run for this annotation

Codecov / codecov/patch

src/implementations/truncation.jl#L104-L105

Added lines #L104 - L105 were not covered by tests
end
function Base.:&(trunc1::TruncationIntersection, trunc2::TruncationStrategy)
return TruncationIntersection((trunc1.components..., trunc2))

Check warning on line 108 in src/implementations/truncation.jl

View check run for this annotation

Codecov / codecov/patch

src/implementations/truncation.jl#L107-L108

Added lines #L107 - L108 were not covered by tests
end
function Base.:&(trunc1::TruncationStrategy, trunc2::TruncationIntersection)
return TruncationIntersection((trunc1, trunc2.components...))

Check warning on line 111 in src/implementations/truncation.jl

View check run for this annotation

Codecov / codecov/patch

src/implementations/truncation.jl#L110-L111

Added lines #L110 - L111 were not covered by tests
end

# truncate!
# ---------
# Generic implementation: `findtruncated` followed by indexing
Expand Down Expand Up @@ -147,6 +176,11 @@
return 1:i
end

function findtruncated(values::AbstractVector, strategy::TruncationIntersection)
inds = map(Base.Fix1(findtruncated, values), strategy.components)
return intersect(inds...)
end

"""
TruncatedAlgorithm(alg::AbstractAlgorithm, trunc::TruncationAlgorithm)

Expand Down
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
using SafeTestsets

@safetestset "Truncate" begin
include("truncate.jl")
end
@safetestset "QR / LQ Decomposition" begin
include("qr.jl")
include("lq.jl")
Expand Down
32 changes: 31 additions & 1 deletion test/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ using Test
using TestExtras
using StableRNGs
using LinearAlgebra: LinearAlgebra, Diagonal, I, isposdef
using MatrixAlgebraKit: diagview
using MatrixAlgebraKit: TruncationKeepAbove, diagview

@testset "svd_compact! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
rng = StableRNG(123)
Expand Down Expand Up @@ -115,3 +115,33 @@ end
end
end
end

@testset "svd_trunc! mix maxrank and tol for T = $T" for T in
(Float32, Float64, ComplexF32,
ComplexF64)
rng = StableRNG(123)
if LinearAlgebra.LAPACK.version() < v"3.12.0"
algs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection())
else
algs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection(),
LAPACK_Jacobi())
end
m = 4
@testset "algorithm $alg" for alg in algs
U = qr_compact(randn(rng, T, m, m))[1]
S = Diagonal([0.9, 0.3, 0.1, 0.01])
Vᴴ = qr_compact(randn(rng, T, m, m))[1]
A = U * S * Vᴴ

for trunc_fun in ((rtol, maxrank) -> (; rtol, maxrank),
(rtol, maxrank) -> truncrank(maxrank) & TruncationKeepAbove(0, rtol))
U1, S1, V1ᴴ = svd_trunc(A; alg, trunc=trunc_fun(0.2, 1))
@test length(S1.diag) == 1
@test S1.diag ≈ S.diag[1:1] rtol = sqrt(eps(real(T)))

U2, S2, V2ᴴ = svd_trunc(A; alg, trunc=trunc_fun(0.2, 3))
@test length(S2.diag) == 2
@test S2.diag ≈ S.diag[1:2] rtol = sqrt(eps(real(T)))
end
end
end
29 changes: 29 additions & 0 deletions test/truncate.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
using MatrixAlgebraKit
using Test
using TestExtras
using MatrixAlgebraKit: NoTruncation, TruncationIntersection, TruncationKeepAbove,
TruncationStrategy

@testset "truncate" begin
trunc = @constinferred TruncationStrategy()
@test trunc isa NoTruncation

trunc = @constinferred TruncationStrategy(; atol=1e-2, rtol=1e-3)
@test trunc isa TruncationKeepAbove
@test trunc == TruncationKeepAbove(1e-2, 1e-3)
@test trunc.atol == 1e-2
@test trunc.rtol == 1e-3

trunc = @constinferred TruncationStrategy(; maxrank=10)
@test trunc isa TruncationKeepSorted
@test trunc == truncrank(10)
@test trunc.howmany == 10
@test trunc.sortby == abs
@test trunc.rev == true

trunc = @constinferred TruncationStrategy(; atol=1e-2, rtol=1e-3, maxrank=10)
@test trunc isa TruncationIntersection
@test trunc == truncrank(10) & TruncationKeepAbove(1e-2, 1e-3)
@test trunc.components[1] == truncrank(10)
@test trunc.components[2] == TruncationKeepAbove(1e-2, 1e-3)
end
Loading