Skip to content

Commit 40276bf

Browse files
committed
Allow customizing truncate! based on the algorithm
1 parent c46119e commit 40276bf

File tree

5 files changed

+37
-16
lines changed

5 files changed

+37
-16
lines changed

src/implementations/eig.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,5 +84,5 @@ end
8484

8585
function eig_trunc!(A::AbstractMatrix, DV, alg::TruncatedAlgorithm)
8686
D, V = eig_full!(A, DV, alg.alg)
87-
return truncate!(eig_trunc!, (D, V), alg.trunc)
87+
return truncate!(eig_trunc!, (D, V), alg)
8888
end

src/implementations/eigh.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,5 +86,5 @@ end
8686

8787
function eigh_trunc!(A::AbstractMatrix, DV, alg::TruncatedAlgorithm)
8888
D, V = eigh_full!(A, DV, alg.alg)
89-
return truncate!(eigh_trunc!, (D, V), alg.trunc)
89+
return truncate!(eigh_trunc!, (D, V), alg)
9090
end

src/implementations/orthnull.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ function left_null!(A::AbstractMatrix, N; trunc=nothing,
180180
trunc′ = trunc isa TruncationStrategy ? trunc :
181181
trunc isa NamedTuple ? null_truncation_strategy(; trunc...) :
182182
throw(ArgumentError("Unknown truncation strategy: $trunc"))
183-
return truncate!(left_null!, (U, S), trunc′)
183+
return truncate!(left_null!, (U, S), TruncatedAlgorithm(alg_svd′, trunc′))
184184
else
185185
throw(ArgumentError("`left_null!` received unknown value `kind = $kind`"))
186186
end
@@ -207,7 +207,7 @@ function right_null!(A::AbstractMatrix, Nᴴ; trunc=nothing,
207207
trunc′ = trunc isa TruncationStrategy ? trunc :
208208
trunc isa NamedTuple ? null_truncation_strategy(; trunc...) :
209209
throw(ArgumentError("Unknown truncation strategy: $trunc"))
210-
return truncate!(right_null!, (S, Vᴴ), trunc′)
210+
return truncate!(right_null!, (S, Vᴴ), TruncatedAlgorithm(alg_svd′, trunc′))
211211
else
212212
throw(ArgumentError("`right_null!` received unknown value `kind = $kind`"))
213213
end

src/implementations/svd.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,5 +170,5 @@ end
170170

171171
function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm)
172172
USVᴴ′ = svd_compact!(A, USVᴴ, alg.alg)
173-
return truncate!(svd_trunc!, USVᴴ′, alg.trunc)
173+
return truncate!(svd_trunc!, USVᴴ′, alg)
174174
end

src/implementations/truncation.jl

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -127,33 +127,65 @@ function Base.:&(trunc1::TruncationStrategy, trunc2::TruncationIntersection)
127127
return TruncationIntersection((trunc1, trunc2.components...))
128128
end
129129

130+
"""
131+
TruncatedAlgorithm(alg::AbstractAlgorithm, trunc::TruncationAlgorithm)
132+
133+
Generic wrapper type for algorithms that consist of first using `alg`, followed by a
134+
truncation through `trunc`.
135+
"""
136+
struct TruncatedAlgorithm{A,T} <: AbstractAlgorithm
137+
alg::A
138+
trunc::T
139+
end
140+
130141
# truncate!
131142
# ---------
132143
# Generic implementation: `findtruncated` followed by indexing
133144
@doc """
134145
truncate!(f, out, strategy::TruncationStrategy)
146+
truncate!(f, out, alg::AbstractAlgorithm)
135147
136148
Generic interface for post-truncating a decomposition, specified in `out`.
137149
""" truncate!
150+
138151
# TODO: should we return a view?
152+
function truncate!(::typeof(svd_trunc!), USVᴴ, alg::TruncatedAlgorithm)
153+
return truncate!(svd_trunc!, USVᴴ, alg.trunc)
154+
end
139155
function truncate!(::typeof(svd_trunc!), (U, S, Vᴴ), strategy::TruncationStrategy)
140156
ind = findtruncated(diagview(S), strategy)
141157
return U[:, ind], Diagonal(diagview(S)[ind]), Vᴴ[ind, :]
142158
end
159+
160+
function truncate!(::typeof(eig_trunc!), DV, alg::TruncatedAlgorithm)
161+
return truncate!(eig_trunc!, DV, alg.trunc)
162+
end
143163
function truncate!(::typeof(eig_trunc!), (D, V), strategy::TruncationStrategy)
144164
ind = findtruncated(diagview(D), strategy)
145165
return Diagonal(diagview(D)[ind]), V[:, ind]
146166
end
167+
168+
function truncate!(::typeof(eigh_trunc!), DV, alg::TruncatedAlgorithm)
169+
return truncate!(eigh_trunc!, DV, alg.trunc)
170+
end
147171
function truncate!(::typeof(eigh_trunc!), (D, V), strategy::TruncationStrategy)
148172
ind = findtruncated(diagview(D), strategy)
149173
return Diagonal(diagview(D)[ind]), V[:, ind]
150174
end
175+
176+
function truncate!(::typeof(left_null!), US, alg::TruncatedAlgorithm)
177+
return truncate!(left_null!, US, alg.trunc)
178+
end
151179
function truncate!(::typeof(left_null!), (U, S), strategy::TruncationStrategy)
152180
# TODO: avoid allocation?
153181
extended_S = vcat(diagview(S), zeros(eltype(S), max(0, size(S, 1) - size(S, 2))))
154182
ind = findtruncated(extended_S, strategy)
155183
return U[:, ind]
156184
end
185+
186+
function truncate!(::typeof(right_null!), SVᴴ, alg::TruncatedAlgorithm)
187+
return truncate!(right_null!, SVᴴ, alg.trunc)
188+
end
157189
function truncate!(::typeof(right_null!), (S, Vᴴ), strategy::TruncationStrategy)
158190
# TODO: avoid allocation?
159191
extended_S = vcat(diagview(S), zeros(eltype(S), max(0, size(S, 2) - size(S, 1))))
@@ -196,14 +228,3 @@ function findtruncated(values::AbstractVector, strategy::TruncationIntersection)
196228
inds = map(Base.Fix1(findtruncated, values), strategy.components)
197229
return intersect(inds...)
198230
end
199-
200-
"""
201-
TruncatedAlgorithm(alg::AbstractAlgorithm, trunc::TruncationAlgorithm)
202-
203-
Generic wrapper type for algorithms that consist of first using `alg`, followed by a
204-
truncation through `trunc`.
205-
"""
206-
struct TruncatedAlgorithm{A,T} <: AbstractAlgorithm
207-
alg::A
208-
trunc::T
209-
end

0 commit comments

Comments
 (0)