Skip to content

Commit f166200

Browse files
authored
Add truncerr truncation strategy (#59)
1 parent d4f0c86 commit f166200

File tree

5 files changed

+310
-111
lines changed

5 files changed

+310
-111
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "TensorAlgebra"
22
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.3.3"
4+
version = "0.3.4"
55

66
[deps]
77
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"

docs/src/reference.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Reference
22

33
```@autodocs
4-
Modules = [TensorAlgebra]
4+
Modules = [TensorAlgebra, TensorAlgebra.MatrixAlgebra]
55
```

src/MatrixAlgebra.jl

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ export eigen,
1717
svd,
1818
svd!,
1919
svdvals,
20-
svdvals!
20+
svdvals!,
21+
truncerr
2122

2223
using LinearAlgebra: LinearAlgebra
2324
using MatrixAlgebraKit
@@ -133,4 +134,43 @@ for (factorize, orth_f) in ((:factorize, :(MatrixAlgebra.orth)), (:factorize!, :
133134
end
134135
end
135136

137+
using MatrixAlgebraKit: MatrixAlgebraKit, TruncationStrategy
138+
139+
struct TruncationError{T<:Real} <: TruncationStrategy
140+
atol::T
141+
rtol::T
142+
p::Int
143+
end
144+
145+
"""
146+
truncerr(; atol::Real=0, rtol::Real=0, p::Int=2)
147+
148+
Create a truncation strategy for truncating such that the error in the factorization
149+
is smaller than `max(atol, rtol * norm)`, where the error is determined using the `p`-norm.
150+
"""
151+
function truncerr(; atol::Real=0, rtol::Real=0, p::Int=2)
152+
return TruncationError(promote(atol, rtol)..., p)
153+
end
154+
155+
function MatrixAlgebraKit.findtruncated(values::AbstractVector, strategy::TruncationError)
156+
Base.require_one_based_indexing(values)
157+
issorted(values; rev=true) || error("Not sorted.")
158+
# norm(values, p) ^ p
159+
normᵖ = sum(Base.Fix2(^, strategy.p) abs, values)
160+
ϵᵖ = max(strategy.atol ^ strategy.p, strategy.rtol ^ strategy.p * normᵖ)
161+
if ϵᵖ normᵖ
162+
return Base.OneTo(0)
163+
end
164+
truncerrᵖ = zero(real(eltype(values)))
165+
rank = length(values)
166+
for i in reverse(eachindex(values))
167+
truncerrᵖ += abs(values[i]) ^ strategy.p
168+
if truncerrᵖ ϵᵖ
169+
rank = i
170+
break
171+
end
172+
end
173+
return Base.OneTo(rank)
174+
end
175+
136176
end

test/test_exports.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ using TensorAlgebra: TensorAlgebra
4545
:svd!,
4646
:svdvals,
4747
:svdvals!,
48+
:truncerr,
4849
]
4950
@test issetequal(names(TensorAlgebra.MatrixAlgebra), exports)
5051
end

0 commit comments

Comments
 (0)