Skip to content

Commit 4e3c3d5

Browse files
committed
added column pivoted QR solve to math::linalg::non_distributed
1 parent 3b1e3af commit 4e3c3d5

File tree

3 files changed

+42
-6
lines changed

3 files changed

+42
-6
lines changed

src/TiledArray/math/linalg/non-distributed/qr.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,21 @@ auto householder_qr(const ArrayV& V, TiledRange q_trange = TiledRange(),
3434
}
3535
}
3636

37+
template <typename ArrayA, typename ArrayB, typename T = ArrayB::numeric_type>
38+
auto qr_solve(const ArrayA& A, const ArrayB& B,
39+
const TiledArray::detail::real_t<T> cond = 1e8,
40+
TiledRange x_trange = TiledRange()) {
41+
(void)detail::array_traits<ArrayB>{};
42+
auto& world = B.world();
43+
auto A_eig = detail::make_matrix(A);
44+
auto B_eig = detail::make_matrix(B);
45+
TA_LAPACK_ON_RANK_ZERO(qr_solve, world, A_eig, B_eig, cond);
46+
world.gop.broadcast_serializable(B_eig, 0);
47+
if (x_trange.rank() == 0) x_trange = B.trange();
48+
auto X = eigen_to_array<ArrayB>(world, x_trange, B_eig);
49+
return X;
50+
}
51+
3752
} // namespace TiledArray::math::linalg::non_distributed
3853

3954
#endif

src/TiledArray/math/linalg/rank-local.cpp

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,22 @@ void cholesky_lsolve(Op transpose, Matrix<T>& A, Matrix<T>& X) {
112112
TA_LAPACK(trtrs, uplo, transpose, diag, n, nrhs, a, lda, b, ldb);
113113
}
114114

115+
template <typename T>
116+
void qr_solve(Matrix<T>& A, Matrix<T>& B,
117+
const TiledArray::detail::real_t<T> cond) {
118+
integer m = A.rows();
119+
integer n = A.cols();
120+
integer nrhs = B.cols();
121+
T* a = A.data();
122+
integer lda = A.rows();
123+
T* b = B.data();
124+
integer ldb = B.rows();
125+
std::vector<integer> jpiv(n);
126+
const TiledArray::detail::real_t<T> rcond = 1 / cond;
127+
integer rank = -1;
128+
TA_LAPACK(gelsy, m, n, nrhs, a, lda, b, ldb, jpiv.data(), rcond, &rank);
129+
}
130+
115131
template <typename T>
116132
void heig(Matrix<T>& A, std::vector<TiledArray::detail::real_t<T>>& W) {
117133
auto jobz = lapack::Job::Vec;
@@ -250,7 +266,7 @@ void householder_qr(Matrix<T>& V, Matrix<T>& R) {
250266
lapack::orgqr(m, n, k, v, ldv, tau.data());
251267
}
252268

253-
#define TA_LAPACK_EXPLICIT(MATRIX, VECTOR) \
269+
#define TA_LAPACK_EXPLICIT(MATRIX, VECTOR, DOUBLE) \
254270
template void cholesky(MATRIX&); \
255271
template void cholesky_linv(MATRIX&); \
256272
template void cholesky_solve(MATRIX&, MATRIX&); \
@@ -261,11 +277,12 @@ void householder_qr(Matrix<T>& V, Matrix<T>& R) {
261277
template void lu_solve(MATRIX&, MATRIX&); \
262278
template void lu_inv(MATRIX&); \
263279
template void householder_qr<true>(MATRIX&, MATRIX&); \
264-
template void householder_qr<false>(MATRIX&, MATRIX&);
280+
template void householder_qr<false>(MATRIX&, MATRIX&); \
281+
template void qr_solve(MATRIX&, MATRIX&, DOUBLE)
265282

266-
TA_LAPACK_EXPLICIT(Matrix<double>, std::vector<double>);
267-
TA_LAPACK_EXPLICIT(Matrix<float>, std::vector<float>);
268-
TA_LAPACK_EXPLICIT(Matrix<std::complex<double>>, std::vector<double>);
269-
TA_LAPACK_EXPLICIT(Matrix<std::complex<float>>, std::vector<float>);
283+
TA_LAPACK_EXPLICIT(Matrix<double>, std::vector<double>, double );
284+
TA_LAPACK_EXPLICIT(Matrix<float>, std::vector<float>, float);
285+
TA_LAPACK_EXPLICIT(Matrix<std::complex<double>>, std::vector<double>, double);
286+
TA_LAPACK_EXPLICIT(Matrix<std::complex<float>>, std::vector<float>, float);
270287

271288
} // namespace TiledArray::math::linalg::rank_local

src/TiledArray/math/linalg/rank-local.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ void cholesky_solve(Matrix<T> &A, Matrix<T> &X);
4141
template <typename T>
4242
void cholesky_lsolve(Op transpose, Matrix<T> &A, Matrix<T> &X);
4343

44+
template <typename T>
45+
void qr_solve(Matrix<T> &A, Matrix<T> &B,
46+
const TiledArray::detail::real_t<T> cond = 1e8);
47+
4448
template <typename T>
4549
void heig(Matrix<T> &A, std::vector<TiledArray::detail::real_t<T>> &W);
4650

0 commit comments

Comments
 (0)