-
Notifications
You must be signed in to change notification settings - Fork 5
Refactor orthogonalization and nullspace interface #79
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 46 commits
6d43f15
9c558d9
651dca3
d4da1ec
136cf62
545ef5d
1d06840
b68d8c8
4c31a36
18b5a94
3b98df6
1541f9f
08944ec
67dc55a
3fc6cbb
fdbb86c
0559aee
86f0c8d
84537c7
e8d8b12
01ee394
c0fda8c
3bc0081
d3e155b
437ed8c
f6266ea
6bc158a
c7942c0
f547ae0
b3407a2
916fc28
14bc3c7
d45cfa9
2b62dfd
adc36bb
9c916af
185d939
f89a318
878b428
1744416
54dda05
15ff6ae
aff5766
afa64bc
fb60fe3
a2b7e9a
f044228
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,7 +4,7 @@ using MatrixAlgebraKit | |
| using MatrixAlgebraKit: @algdef, Algorithm, check_input | ||
| using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular! | ||
| using MatrixAlgebraKit: diagview, sign_safe | ||
| using MatrixAlgebraKit: LQViaTransposedQR, TruncationByValue, AbstractAlgorithm | ||
| using MatrixAlgebraKit: LQViaTransposedQR, TruncationStrategy, NoTruncation, TruncationByValue, AbstractAlgorithm | ||
| using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eigh_algorithm | ||
| import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_gesvdj! | ||
| import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _gpu_heev!, _gpu_heevx! | ||
|
|
@@ -161,7 +161,9 @@ function MatrixAlgebraKit._avgdiff!(A::StridedROCMatrix, B::StridedROCMatrix) | |
| return A, B | ||
| end | ||
|
|
||
| function MatrixAlgebraKit.truncate(::typeof(MatrixAlgebraKit.left_null!), US::Tuple{TU, TS}, strategy::MatrixAlgebraKit.TruncationStrategy) where {TU <: ROCArray, TS} | ||
| function MatrixAlgebraKit.truncate( | ||
| ::typeof(left_null!), US::Tuple{TU, TS}, strategy::TruncationStrategy | ||
| ) where {TU <: ROCMatrix, TS} | ||
| # TODO: avoid allocation? | ||
| U, S = US | ||
| extended_S = vcat(diagview(S), zeros(eltype(S), max(0, size(S, 1) - size(S, 2)))) | ||
|
|
@@ -170,5 +172,32 @@ function MatrixAlgebraKit.truncate(::typeof(MatrixAlgebraKit.left_null!), US::Tu | |
| Utrunc = U[:, trunc_cols] | ||
| return Utrunc, ind | ||
| end | ||
| function MatrixAlgebraKit.truncate( | ||
| ::typeof(right_null!), SVᴴ::Tuple{TS, TVᴴ}, strategy::TruncationStrategy | ||
| ) where {TS, TVᴴ <: ROCMatrix} | ||
| # TODO: avoid allocation? | ||
| S, Vᴴ = SVᴴ | ||
| extended_S = vcat(diagview(S), zeros(eltype(S), max(0, size(S, 2) - size(S, 1)))) | ||
| ind = MatrixAlgebraKit.findtruncated(extended_S, strategy) | ||
| trunc_rows = collect(1:size(Vᴴ, 1))[ind] | ||
| Vᴴtrunc = Vᴴ[trunc_rows, :] | ||
| return Vᴴtrunc, ind | ||
| end | ||
|
|
||
| # disambiguate: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are the methods below ambiguous with the methods above? They seem strictly more specific. Or are they ambiguous with methods that specify
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The latter, I special-cased |
||
| function MatrixAlgebraKit.truncate( | ||
| ::typeof(left_null!), (U, S)::Tuple{TU, TS}, ::NoTruncation | ||
| ) where {TU <: ROCMatrix, TS} | ||
| m, n = size(S) | ||
| ind = (n + 1):m | ||
| return U[:, ind], ind | ||
| end | ||
| function MatrixAlgebraKit.truncate( | ||
| ::typeof(right_null!), (S, Vᴴ)::Tuple{TS, TVᴴ}, ::NoTruncation | ||
| ) where {TS, TVᴴ <: ROCMatrix} | ||
| m, n = size(S) | ||
| ind = (m + 1):n | ||
| return Vᴴ[ind, :], ind | ||
| end | ||
|
|
||
| end | ||
Uh oh!
There was an error while loading. Please reload this page.