Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MatrixAlgebraKit"
uuid = "6c742aac-3347-4629-af66-fc926824e5e4"
authors = ["Jutho <[email protected]> and contributors"]
version = "0.1.1"
version = "0.1.2"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
39 changes: 36 additions & 3 deletions src/implementations/truncation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,17 @@
if isnothing(maxrank) && isnothing(atol) && isnothing(rtol)
return NoTruncation()
elseif isnothing(maxrank)
@assert isnothing(rtol) "TODO: rtol"
return trunctol(atol)
atol = @something atol 0
rtol = @something rtol 0
return TruncationKeepAbove(atol, rtol)

Check warning on line 16 in src/implementations/truncation.jl

View check run for this annotation

Codecov / codecov/patch

src/implementations/truncation.jl#L14-L16

Added lines #L14 - L16 were not covered by tests
else
return truncrank(maxrank)
if isnothing(atol) && isnothing(rtol)
return truncrank(maxrank)

Check warning on line 19 in src/implementations/truncation.jl

View check run for this annotation

Codecov / codecov/patch

src/implementations/truncation.jl#L19

Added line #L19 was not covered by tests
else
atol = @something atol 0
rtol = @something rtol 0
return truncrank(maxrank) & TruncationKeepAbove(atol, rtol)
end
end
end

Expand Down Expand Up @@ -82,6 +89,27 @@
"""
truncabove(atol) = TruncationKeepFiltered(≤(atol) ∘ abs)

"""
TruncationComposition(trunc1::TruncationStrategy, trunc2::TruncationStrategy)
Compose two truncation strategies, keeping values common between the two strategies.
"""
struct TruncationComposition{T<:Tuple{Vararg{TruncationStrategy}}} <:
TruncationStrategy
components::T
end
function Base.:&(trunc1::TruncationStrategy, trunc2::TruncationStrategy)
return TruncationComposition((trunc1, trunc2))
end
function Base.:&(trunc1::TruncationComposition, trunc2::TruncationComposition)
return TruncationComposition((trunc1.components..., trunc2.components...))

Check warning on line 104 in src/implementations/truncation.jl

View check run for this annotation

Codecov / codecov/patch

src/implementations/truncation.jl#L103-L104

Added lines #L103 - L104 were not covered by tests
end
function Base.:&(trunc1::TruncationComposition, trunc2::TruncationStrategy)
return TruncationComposition((trunc1.components..., trunc2))

Check warning on line 107 in src/implementations/truncation.jl

View check run for this annotation

Codecov / codecov/patch

src/implementations/truncation.jl#L106-L107

Added lines #L106 - L107 were not covered by tests
end
function Base.:&(trunc1::TruncationStrategy, trunc2::TruncationComposition)
return TruncationComposition((trunc1, trunc2.components...))

Check warning on line 110 in src/implementations/truncation.jl

View check run for this annotation

Codecov / codecov/patch

src/implementations/truncation.jl#L109-L110

Added lines #L109 - L110 were not covered by tests
end

# truncate!
# ---------
# Generic implementation: `findtruncated` followed by indexing
Expand Down Expand Up @@ -147,6 +175,11 @@
return 1:i
end

function findtruncated(values::AbstractVector, strategy::TruncationComposition)
inds = map(Base.Fix1(findtruncated, values), strategy.components)
return intersect(inds...)
end

"""
TruncatedAlgorithm(alg::AbstractAlgorithm, trunc::TruncationAlgorithm)

Expand Down
27 changes: 27 additions & 0 deletions test/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,30 @@ end
end
end
end

@testset "svd_trunc! mix maxrank and tol for T = $T" for T in
(Float32, Float64, ComplexF32,
ComplexF64)
rng = StableRNG(123)
if LinearAlgebra.LAPACK.version() < v"3.12.0"
algs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection())
else
algs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection(),
LAPACK_Jacobi())
end
m = 4
@testset "algorithm $alg" for alg in algs
U = qr_compact(randn(rng, T, m, m))[1]
S = Diagonal([0.9, 0.3, 0.1, 0.01])
Vᴴ = qr_compact(randn(rng, T, m, m))[1]
A = U * S * Vᴴ

U1, S1, V1ᴴ = svd_trunc(A; alg, trunc=(; rtol=0.2, maxrank=1))
@test length(S1.diag) == 1
@test S1.diag ≈ S.diag[1:1] rtol = sqrt(eps(real(T)))

U2, S2, V2ᴴ = svd_trunc(A; alg, trunc=(; rtol=0.2, maxrank=3))
@test length(S2.diag) == 2
@test S2.diag ≈ S.diag[1:2] rtol = sqrt(eps(real(T)))
end
end