@@ -11,10 +11,17 @@ function TruncationStrategy(; atol=nothing, rtol=nothing, maxrank=nothing)
1111 if isnothing(maxrank) && isnothing(atol) && isnothing(rtol)
1212 return NoTruncation()
1313 elseif isnothing(maxrank)
14- @assert isnothing(rtol) " TODO: rtol"
15- return trunctol(atol)
14+ atol = @something atol 0
15+ rtol = @something rtol 0
16+ return TruncationKeepAbove(atol, rtol)
1617 else
17- return truncrank(maxrank)
18+ if isnothing(atol) && isnothing(rtol)
19+ return truncrank(maxrank)
20+ else
21+ atol = @something atol 0
22+ rtol = @something rtol 0
23+ return truncrank(maxrank) & TruncationKeepAbove(atol, rtol)
24+ end
1825 end
1926end
2027
@@ -82,6 +89,28 @@ Truncation strategy to discard the values that are larger than `atol` in absolut
8289"""
8390truncabove(atol) = TruncationKeepFiltered(≤ (atol) ∘ abs)
8491
92+ """
93+ TruncationIntersection(trunc1::TruncationStrategy, trunc2::TruncationStrategy)
94+
95+ Compose two truncation strategies, keeping values common between the two strategies.
96+ """
97+ struct TruncationIntersection{T<: Tuple{Vararg{TruncationStrategy}} } < :
98+ TruncationStrategy
99+ components:: T
100+ end
101+ function Base.:& (trunc1:: TruncationStrategy , trunc2:: TruncationStrategy )
102+ return TruncationIntersection((trunc1, trunc2))
103+ end
104+ function Base.:& (trunc1:: TruncationIntersection , trunc2:: TruncationIntersection )
105+ return TruncationIntersection((trunc1. components... , trunc2. components... ))
106+ end
107+ function Base.:& (trunc1:: TruncationIntersection , trunc2:: TruncationStrategy )
108+ return TruncationIntersection((trunc1. components... , trunc2))
109+ end
110+ function Base.:& (trunc1:: TruncationStrategy , trunc2:: TruncationIntersection )
111+ return TruncationIntersection((trunc1, trunc2. components... ))
112+ end
113+
85114# truncate!
86115# ---------
87116# Generic implementation: `findtruncated` followed by indexing
@@ -147,6 +176,11 @@ function findtruncated(values::AbstractVector, strategy::TruncationKeepAbove)
147176 return 1 : i
148177end
149178
179+ function findtruncated(values:: AbstractVector , strategy:: TruncationIntersection )
180+ inds = map(Base. Fix1(findtruncated, values), strategy. components)
181+ return intersect(inds... )
182+ end
183+
150184"""
151185 TruncatedAlgorithm(alg::AbstractAlgorithm, trunc::TruncationAlgorithm)
152186
0 commit comments