Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ext/MatrixAlgebraKitChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -176,4 +176,4 @@ function ChainRulesCore.rrule(::typeof(right_polar!), A::AbstractMatrix, PWᴴ,
return PWᴴ, right_polar_pullback
end

end
end
2 changes: 1 addition & 1 deletion src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -178,4 +178,4 @@ macro check_size(x, sz, size=:size)
string($sz)
szx == $sz || throw(DimensionMismatch($err))
end)
end
end
2 changes: 1 addition & 1 deletion src/implementations/orthnull.jl
Original file line number Diff line number Diff line change
Expand Up @@ -215,4 +215,4 @@ function right_null!(A::AbstractMatrix, Nᴴ; kwargs...)
else
throw(ArgumentError("`right_null!` received unknown value `kind = $kind`"))
end
end
end
33 changes: 30 additions & 3 deletions src/implementations/truncation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,17 @@
if isnothing(maxrank) && isnothing(atol) && isnothing(rtol)
return NoTruncation()
elseif isnothing(maxrank)
@assert isnothing(rtol) "TODO: rtol"
return trunctol(atol)
atol = @something atol 0
rtol = @something rtol 0
return TruncationKeepAbove(atol, rtol)

Check warning on line 16 in src/implementations/truncation.jl

View check run for this annotation

Codecov / codecov/patch

src/implementations/truncation.jl#L14-L16

Added lines #L14 - L16 were not covered by tests
else
return truncrank(maxrank)
if isnothing(atol) && isnothing(rtol)
return truncrank(maxrank)

Check warning on line 19 in src/implementations/truncation.jl

View check run for this annotation

Codecov / codecov/patch

src/implementations/truncation.jl#L19

Added line #L19 was not covered by tests
else
atol = @something atol 0
rtol = @something rtol 0
return truncrank(maxrank) & TruncationKeepAbove(atol, rtol)
end
end
end

Expand Down Expand Up @@ -82,6 +89,20 @@
"""
truncabove(atol) = TruncationKeepFiltered(≤(atol) ∘ abs)

"""
TruncationComposition(trunc1::TruncationStrategy, trunc2::TruncationStrategy)

Compose two truncation strategies, keeping values common between the two strategies.
"""
struct TruncationComposition{T1<:TruncationStrategy,T2<:TruncationStrategy} <:
TruncationStrategy
trunc1::T1
trunc2::T2
end
function Base.:&(trunc1::TruncationStrategy, trunc2::TruncationStrategy)
return TruncationComposition(trunc1, trunc2)
end

# truncate!
# ---------
# Generic implementation: `findtruncated` followed by indexing
Expand Down Expand Up @@ -147,6 +168,12 @@
return 1:i
end

function findtruncated(values::AbstractVector, strategy::TruncationComposition)
ind1 = findtruncated(values, strategy.trunc1)
ind2 = findtruncated(values, strategy.trunc2)
return ind1 ∩ ind2
end

"""
TruncatedAlgorithm(alg::AbstractAlgorithm, trunc::TruncationAlgorithm)

Expand Down
2 changes: 1 addition & 1 deletion src/interface/orthnull.jl
Original file line number Diff line number Diff line change
Expand Up @@ -224,4 +224,4 @@ function right_null!(A::AbstractMatrix; kwargs...)
end
function right_null(A::AbstractMatrix; kwargs...)
return right_null!(copy_input(right_null, A); kwargs...)
end
end
2 changes: 1 addition & 1 deletion src/pullbacks/polar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,4 @@ function right_polar_pullback!(ΔA::AbstractMatrix, PWᴴ, ΔPWᴴ)
ΔA .+= PΔWᴴ
end
return ΔA
end
end
2 changes: 1 addition & 1 deletion test/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -356,4 +356,4 @@ end
test_rrule(config, right_null, A; fkwargs=(; kind=:lqpos), output_tangent=ΔNᴴ,
atol=atol, rtol=rtol, rrule_f=rrule_via_ad, check_inferred=false)
end
end
end
2 changes: 1 addition & 1 deletion test/orthnull.jl
Original file line number Diff line number Diff line change
Expand Up @@ -209,4 +209,4 @@ end
end
end
end
end
end
26 changes: 26 additions & 0 deletions test/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,29 @@ end
end
end
end

@testset "svd_trunc! mix maxrank and tol for T = $T" for T in (Float32, Float64, ComplexF32,
ComplexF64)
rng = StableRNG(123)
if LinearAlgebra.LAPACK.version() < v"3.12.0"
algs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection())
else
algs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection(),
LAPACK_Jacobi())
end
m = 4
@testset "algorithm $alg" for alg in algs
U = qr_compact(randn(rng, T, m, m))[1]
S = Diagonal([0.9, 0.3, 0.1, 0.01])
Vᴴ = qr_compact(randn(rng, T, m, m))[1]
A = U * S * Vᴴ

U1, S1, V1ᴴ = svd_trunc(A; alg, trunc=(; rtol=0.2, maxrank=1))
@test length(S1.diag) == 1
@test S1.diag ≈ S.diag[1:1] rtol = sqrt(eps(real(T)))

U2, S2, V2ᴴ = svd_trunc(A; alg, trunc=(; rtol=0.2, maxrank=3))
@test length(S2.diag) == 2
@test S2.diag ≈ S.diag[1:2] rtol = sqrt(eps(real(T)))
end
end