Skip to content

Commit 3871df1

Browse files
authored
TruncationStrategy types and constructors: consistency in names and implementations (#56)
1 parent 99e4611 commit 3871df1

File tree

22 files changed

+218
-214
lines changed

22 files changed

+218
-214
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MatrixAlgebraKit"
22
uuid = "6c742aac-3347-4629-af66-fc926824e5e4"
33
authors = ["Jutho <[email protected]> and contributors"]
4-
version = "0.3.2"
4+
version = "0.4.0"
55

66
[deps]
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

docs/src/dev_interface.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,5 @@ MatrixAlgebraKit.jl provides a developer interface for specifying custom algorit
1111
MatrixAlgebraKit.default_algorithm
1212
MatrixAlgebraKit.select_algorithm
1313
MatrixAlgebraKit.findtruncated
14-
MatrixAlgebraKit.findtruncated_sorted
14+
MatrixAlgebraKit.findtruncated_svd
1515
```

docs/src/user_interface/truncations.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ Currently, truncations are supported through the following different methods:
1111
notrunc
1212
truncrank
1313
trunctol
14-
truncabove
14+
truncfilter
1515
truncerror
1616
```
1717

@@ -20,6 +20,6 @@ For example, truncating to a maximal dimension `10`, and discarding all values b
2020

2121
```julia
2222
maxdim = 10
23-
tol = 1e-6
24-
combined_trunc = truncrank(maxdim) & trunctol(tol)
23+
atol = 1e-6
24+
combined_trunc = truncrank(maxdim) & trunctol(; atol)
2525
```

ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using MatrixAlgebraKit
44
using MatrixAlgebraKit: @algdef, Algorithm, check_input
55
using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular!
66
using MatrixAlgebraKit: diagview, sign_safe
7-
using MatrixAlgebraKit: LQViaTransposedQR
7+
using MatrixAlgebraKit: LQViaTransposedQR, TruncationByValue
88
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eigh_algorithm
99
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_gesvdj!
1010
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _gpu_heev!, _gpu_heevx!
@@ -40,4 +40,9 @@ _gpu_heevj!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwar
4040
_gpu_heevd!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) = YArocSOLVER.heevd!(A, Dd, V; kwargs...)
4141
_gpu_heev!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) = YArocSOLVER.heev!(A, Dd, V; kwargs...)
4242
_gpu_heevx!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) = YArocSOLVER.heevx!(A, Dd, V; kwargs...)
43+
44+
function MatrixAlgebraKit.findtruncated_svd(values::StridedROCVector, strategy::TruncationByValue)
45+
return MatrixAlgebraKit.findtruncated(values, strategy)
46+
end
47+
4348
end

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using MatrixAlgebraKit
44
using MatrixAlgebraKit: @algdef, Algorithm, check_input
55
using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular!
66
using MatrixAlgebraKit: diagview, sign_safe
7-
using MatrixAlgebraKit: LQViaTransposedQR
7+
using MatrixAlgebraKit: LQViaTransposedQR, TruncationByValue
88
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm
99
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!, _gpu_geev!
1010
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!
@@ -44,4 +44,8 @@ _gpu_gesvdj!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::S
4444
_gpu_heevj!(A::StridedCuMatrix, Dd::StridedCuVector, V::StridedCuMatrix; kwargs...) = YACUSOLVER.heevj!(A, Dd, V; kwargs...)
4545
_gpu_heevd!(A::StridedCuMatrix, Dd::StridedCuVector, V::StridedCuMatrix; kwargs...) = YACUSOLVER.heevd!(A, Dd, V; kwargs...)
4646

47+
function MatrixAlgebraKit.findtruncated_svd(values::StridedCuVector, strategy::TruncationByValue)
48+
return MatrixAlgebraKit.findtruncated(values, strategy)
49+
end
50+
4751
end

src/MatrixAlgebraKit.jl

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,20 +28,21 @@ export left_polar!, right_polar!
2828
export left_orth, right_orth, left_null, right_null
2929
export left_orth!, right_orth!, left_null!, right_null!
3030

31-
export LAPACK_HouseholderQR, LAPACK_HouseholderLQ,
32-
LAPACK_Simple, LAPACK_Expert,
33-
LAPACK_QRIteration, LAPACK_Bisection, LAPACK_MultipleRelativelyRobustRepresentations,
34-
LAPACK_DivideAndConquer, LAPACK_Jacobi,
35-
LQViaTransposedQR,
36-
CUSOLVER_Simple,
37-
CUSOLVER_HouseholderQR, CUSOLVER_QRIteration, CUSOLVER_SVDPolar, CUSOLVER_Jacobi, CUSOLVER_Randomized, CUSOLVER_DivideAndConquer,
38-
ROCSOLVER_HouseholderQR, ROCSOLVER_QRIteration, ROCSOLVER_Jacobi, ROCSOLVER_DivideAndConquer, ROCSOLVER_Bisection,
39-
DiagonalAlgorithm
40-
export truncrank, trunctol, truncabove, TruncationKeepSorted, TruncationKeepFiltered, truncerror
31+
export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LAPACK_Simple, LAPACK_Expert, LAPACK_QRIteration, LAPACK_Bisection, LAPACK_MultipleRelativelyRobustRepresentations, LAPACK_DivideAndConquer, LAPACK_Jacobi
32+
export LQViaTransposedQR
33+
export DiagonalAlgorithm
34+
export CUSOLVER_Simple, CUSOLVER_HouseholderQR, CUSOLVER_QRIteration, CUSOLVER_SVDPolar, CUSOLVER_Jacobi, CUSOLVER_Randomized, CUSOLVER_DivideAndConquer
35+
export ROCSOLVER_HouseholderQR, ROCSOLVER_QRIteration, ROCSOLVER_Jacobi,
36+
ROCSOLVER_DivideAndConquer, ROCSOLVER_Bisection
4137

42-
VERSION >= v"1.11.0-DEV.469" &&
43-
eval(Expr(:public, :default_algorithm, :findtruncated, :findtruncated_sorted,
38+
export notrunc, truncrank, trunctol, truncerror, truncfilter
39+
40+
@static if VERSION >= v"1.11.0-DEV.469"
41+
eval(Expr(:public, :default_algorithm, :findtruncated, :findtruncated_svd,
4442
:select_algorithm))
43+
eval(Expr(:public, :TruncationByOrder, :TruncationByFilter, :TruncationByValue,
44+
:TruncationByError, :TruncationIntersection))
45+
end
4546

4647
include("common/defaults.jl")
4748
include("common/initialization.jl")

src/algorithms.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -168,16 +168,16 @@ based on the `strategy`. The output should be a collection of indices specifying
168168
which values to keep. `MatrixAlgebraKit.findtruncated` is used inside of the default
169169
implementation of [`truncate!`](@ref) to perform the truncation. It does not assume that the
170170
values are sorted. For a version that assumes the values are reverse sorted (which is the
171-
standard case for SVD) see [`MatrixAlgebraKit.findtruncated_sorted`](@ref).
171+
standard case for SVD) see [`MatrixAlgebraKit.findtruncated_svd`](@ref).
172172
""" findtruncated
173173

174174
@doc """
175-
MatrixAlgebraKit.findtruncated_sorted(values::AbstractVector, strategy::TruncationStrategy)
175+
MatrixAlgebraKit.findtruncated_svd(values::AbstractVector, strategy::TruncationStrategy)
176176
177177
Like [`MatrixAlgebraKit.findtruncated`](@ref) but assumes that the values are real and
178178
sorted in descending order, as typically obtained by the SVD. This assumption is not
179179
checked, and this is used in the default implementation of [`svd_trunc!`](@ref).
180-
""" findtruncated_sorted
180+
""" findtruncated_svd
181181

182182
"""
183183
TruncatedAlgorithm(alg::AbstractAlgorithm, trunc::TruncationAlgorithm)

src/implementations/orthnull.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,11 +203,11 @@ end
203203
# --------------------------------
204204
function null_truncation_strategy(; atol=nothing, rtol=nothing, maxnullity=nothing)
205205
if isnothing(maxnullity) && isnothing(atol) && isnothing(rtol)
206-
return NoTruncation()
206+
return notrunc()
207207
end
208208
atol = @something atol 0
209209
rtol = @something rtol 0
210-
trunc = TruncationKeepBelow(atol, rtol)
210+
trunc = trunctol(; atol, rtol, keep_below=true)
211211
return !isnothing(maxnullity) ? trunc & truncrank(maxnullity; rev=false) : trunc
212212
end
213213

src/implementations/truncation.jl

Lines changed: 56 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# ---------
33
# Generic implementation: `findtruncated` followed by indexing
44
function truncate!(::typeof(svd_trunc!), (U, S, Vᴴ), strategy::TruncationStrategy)
5-
ind = findtruncated_sorted(diagview(S), strategy)
5+
ind = findtruncated_svd(diagview(S), strategy)
66
return U[:, ind], Diagonal(diagview(S)[ind]), Vᴴ[ind, :]
77
end
88
function truncate!(::typeof(eig_trunc!), (D, V), strategy::TruncationStrategy)
@@ -28,82 +28,89 @@ end
2828

2929
# findtruncated
3030
# -------------
31+
# Generic fallback
32+
function findtruncated_svd(values::AbstractVector, strategy::TruncationStrategy)
33+
return findtruncated(values, strategy)
34+
end
35+
3136
# specific implementations for finding truncated values
3237
findtruncated(values::AbstractVector, ::NoTruncation) = Colon()
3338

34-
function findtruncated(values::AbstractVector, strategy::TruncationKeepSorted)
39+
function findtruncated(values::AbstractVector, strategy::TruncationByOrder)
3540
howmany = min(strategy.howmany, length(values))
36-
return partialsortperm(values, 1:howmany; by=strategy.by, rev=strategy.rev)
41+
return partialsortperm(values, 1:howmany; strategy.by, strategy.rev)
3742
end
38-
function findtruncated_sorted(values::AbstractVector, strategy::TruncationKeepSorted)
43+
function findtruncated_svd(values::AbstractVector, strategy::TruncationByOrder)
44+
strategy.by === abs || return findtruncated(values, strategy)
3945
howmany = min(strategy.howmany, length(values))
40-
return 1:howmany
46+
return strategy.rev ? (1:howmany) : ((length(values) - howmany + 1):length(values))
4147
end
4248

43-
# TODO: consider if worth using that values are sorted when filter is `<` or `>`.
44-
function findtruncated(values::AbstractVector, strategy::TruncationKeepFiltered)
45-
ind = findall(strategy.filter, values)
46-
return ind
49+
function findtruncated(values::AbstractVector, strategy::TruncationByFilter)
50+
return findall(strategy.filter, values)
4751
end
4852

49-
function findtruncated(values::AbstractVector, strategy::TruncationKeepBelow)
53+
function findtruncated(values::AbstractVector, strategy::TruncationByValue)
5054
atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p))
51-
return findall((atol) strategy.by, values)
55+
filter = (strategy.keep_below ? (atol) : (atol)) strategy.by
56+
return findtruncated(values, truncfilter(filter))
5257
end
53-
function findtruncated_sorted(values::AbstractVector, strategy::TruncationKeepBelow)
58+
function findtruncated_svd(values::AbstractVector, strategy::TruncationByValue)
59+
strategy.by === abs || return findtruncated(values, strategy)
5460
atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p))
55-
i = searchsortedfirst(values, atol; by=strategy.by, rev=true)
56-
return i:length(values)
57-
end
58-
59-
function findtruncated(values::AbstractVector, strategy::TruncationKeepAbove)
60-
atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p))
61-
return findall((atol) strategy.by, values)
62-
end
63-
function findtruncated_sorted(values::AbstractVector, strategy::TruncationKeepAbove)
64-
atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p))
65-
i = searchsortedlast(values, atol; by=strategy.by, rev=true)
66-
return 1:i
67-
end
68-
69-
function findtruncated(values::AbstractVector, strategy::TruncationIntersection)
70-
inds = map(Base.Fix1(findtruncated, values), strategy.components)
71-
return intersect(inds...)
72-
end
73-
function findtruncated_sorted(values::AbstractVector, strategy::TruncationIntersection)
74-
inds = map(Base.Fix1(findtruncated_sorted, values), strategy.components)
75-
return intersect(inds...)
61+
if strategy.keep_below
62+
i = searchsortedfirst(values, atol; by=abs, rev=true)
63+
return i:length(values)
64+
else
65+
i = searchsortedlast(values, atol; by=abs, rev=true)
66+
return 1:i
67+
end
7668
end
7769

78-
function findtruncated(values::AbstractVector, strategy::TruncationError)
70+
function findtruncated(values::AbstractVector, strategy::TruncationByError)
7971
I = sortperm(values; by=abs, rev=true)
80-
I′ = _truncerr_impl(values, I, strategy)
72+
I′ = _truncerr_impl(values, I; strategy.atol, strategy.rtol, strategy.p)
8173
return I[I′]
8274
end
83-
function findtruncated_sorted(values::AbstractVector, strategy::TruncationError)
75+
function findtruncated_svd(values::AbstractVector, strategy::TruncationByError)
8476
I = eachindex(values)
85-
I′ = _truncerr_impl(values, I, strategy)
77+
I′ = _truncerr_impl(values, I; strategy.atol, strategy.rtol, strategy.p)
8678
return I[I′]
8779
end
88-
function _truncerr_impl(values::AbstractVector, I, strategy::TruncationError)
89-
Nᵖ = sum(Base.Fix2(^, strategy.p) abs, values)
90-
ϵᵖ = max(strategy.atol^strategy.p, strategy.rtol^strategy.p * Nᵖ)
80+
function _truncerr_impl(values::AbstractVector, I; atol::Real=0, rtol::Real=0, p::Real=2)
81+
by = Base.Fix2(^, p) abs
82+
Nᵖ = sum(by, values)
83+
ϵᵖ = max(atol^p, rtol^p * Nᵖ)
84+
85+
# fast path to avoid checking all values
9186
ϵᵖ Nᵖ && return Base.OneTo(0)
9287

9388
truncerrᵖ = zero(real(eltype(values)))
9489
rank = length(values)
9590
for i in reverse(I)
96-
truncerrᵖ += abs(values[i])^strategy.p
97-
if truncerrᵖ ϵᵖ
98-
break
99-
else
100-
rank -= 1
101-
end
91+
truncerrᵖ += by(values[i])
92+
truncerrᵖ ϵᵖ && break
93+
rank -= 1
10294
end
95+
10396
return Base.OneTo(rank)
10497
end
10598

106-
# Generic fallback
107-
function findtruncated_sorted(values::AbstractVector, strategy::TruncationStrategy)
108-
return findtruncated(values, strategy)
99+
function findtruncated(values::AbstractVector, strategy::TruncationIntersection)
100+
return mapreduce(Base.Fix1(findtruncated, values), _ind_intersect, strategy.components;
101+
init=trues(length(values)))
109102
end
103+
function findtruncated_svd(values::AbstractVector, strategy::TruncationIntersection)
104+
return mapreduce(Base.Fix1(findtruncated_svd, values), _ind_intersect,
105+
strategy.components; init=trues(length(values)))
106+
end
107+
108+
# when one of the ind selections is a bitvector, have to handle differently
109+
function _ind_intersect(A::AbstractVector{Bool}, B::AbstractVector)
110+
result = falses(length(A))
111+
result[B] .= @view A[B]
112+
return result
113+
end
114+
_ind_intersect(A::AbstractVector, B::AbstractVector{Bool}) = _ind_intersect(B, A)
115+
_ind_intersect(A::AbstractVector{Bool}, B::AbstractVector{Bool}) = A .& B
116+
_ind_intersect(A, B) = intersect(A, B)

0 commit comments

Comments
 (0)