Skip to content

Commit 88a91a7

Browse files
committed
Add geqrf for Tensor LAPACK
1 parent 4a83d8f commit 88a91a7

File tree

3 files changed

+82
-0
lines changed

3 files changed

+82
-0
lines changed

source/module_base/module_container/ATen/kernels/lapack.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,26 @@ struct lapack_getrs<T, DEVICE_CPU> {
179179
}
180180
};
181181

182+
183+
template <typename T>
184+
struct lapack_geqrf<T, DEVICE_CPU> {
185+
void operator()(
186+
const int& m,
187+
const int& n,
188+
T* A,
189+
const int& lda,
190+
T* tau,
191+
T* work,
192+
const int& lwork)
193+
{
194+
int info = 0;
195+
lapackConnector::geqrf(m, n, A, lda, tau, work, lwork, info);
196+
if (info != 0) {
197+
throw std::runtime_error("geqrf failed with info = " + std::to_string(info));
198+
}
199+
}
200+
};
201+
182202
template struct set_matrix<float, DEVICE_CPU>;
183203
template struct set_matrix<double, DEVICE_CPU>;
184204
template struct set_matrix<std::complex<float>, DEVICE_CPU>;
@@ -219,5 +239,10 @@ template struct lapack_getrs<double, DEVICE_CPU>;
219239
template struct lapack_getrs<std::complex<float>, DEVICE_CPU>;
220240
template struct lapack_getrs<std::complex<double>, DEVICE_CPU>;
221241

242+
// Explicit CPU instantiations of lapack_geqrf: real scalar types first,
// then their complex counterparts.
template struct lapack_geqrf<float, DEVICE_CPU>;
template struct lapack_geqrf<double, DEVICE_CPU>;

template struct lapack_geqrf<std::complex<float>, DEVICE_CPU>;
template struct lapack_geqrf<std::complex<double>, DEVICE_CPU>;
246+
222247
} // namespace kernels
223248
} // namespace container

source/module_base/module_container/ATen/kernels/lapack.h

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,35 @@ struct lapack_getrs {
249249
const int& ldb);
250250
};
251251

252+
253+
// add geqrf wrapper
254+
template <typename T, typename Device>
struct lapack_geqrf {
    /**
     * @brief QR-factorize an m-by-n matrix on the given Device.
     *
     * The matrix is decomposed as
     *     A = Q * R,
     * with R upper triangular and Q orthogonal (unitary when T is a
     * complex type).
     *
     * NOTE(review): under the usual LAPACK ?geqrf convention the result is
     * presumably written back into A and tau (Q as elementary reflectors)
     * rather than as an explicit Q matrix — confirm against the LAPACK
     * reference before relying on the layout.
     *
     * @param m     Row count of the matrix.
     * @param n     Column count of the matrix.
     * @param A     Pointer to the matrix data.
     * @param lda   Leading dimension of A.
     * @param tau   Pointer to the array holding the scalar factors of the
     *              elementary reflectors.
     * @param work  Pointer to the workspace array.
     * @param lwork Size of the workspace array.
     */
    void operator()(
        const int& m,
        const int& n,
        T* A,
        const int& lda,
        T* tau,
        T* work,
        const int& lwork);
};
280+
252281
#if defined(__CUDA) || defined(__ROCM)
253282
// TODO: Use C++ singleton to manage the GPU handles
254283
void createGpuSolverHandle(); // create cusolver handle

source/module_base/module_container/base/third_party/lapack.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,11 @@ void sgetrs_(const char* trans, const int* n, const int* nrhs, const float* A, c
119119
void dgetrs_(const char* trans, const int* n, const int* nrhs, const double* A, const int* lda, const int* ipiv, double* B, const int* ldb, int* info);
120120
void cgetrs_(const char* trans, const int* n, const int* nrhs, const std::complex<float>* A, const int* lda, const int* ipiv, std::complex<float>* B, const int* ldb, int* info);
121121
void zgetrs_(const char* trans, const int* n, const int* nrhs, const std::complex<double>* A, const int* lda, const int* ipiv, std::complex<double>* B, const int* ldb, int* info);
122+
123+
// Fortran LAPACK QR-factorization entry points (?geqrf), one per scalar type:
// s = float, d = double, c = std::complex<float>, z = std::complex<double>.
void sgeqrf_(const int* m, const int* n,
             float* A, const int* lda,
             float* tau,
             float* work, const int* lwork,
             int* info);
void dgeqrf_(const int* m, const int* n,
             double* A, const int* lda,
             double* tau,
             double* work, const int* lwork,
             int* info);
void cgeqrf_(const int* m, const int* n,
             std::complex<float>* A, const int* lda,
             std::complex<float>* tau,
             std::complex<float>* work, const int* lwork,
             int* info);
void zgeqrf_(const int* m, const int* n,
             std::complex<double>* A, const int* lda,
             std::complex<double>* tau,
             std::complex<double>* work, const int* lwork,
             int* info);
122127
}
123128

124129
// Class LapackConnector provide the connector to fortran lapack routine.
@@ -398,6 +403,29 @@ void getrs(const char& trans, const int n, const int nrhs, std::complex<double>*
398403
zgetrs_(&trans, &n, &nrhs, A, &lda, ipiv, B, &ldb, &info);
399404
}
400405

406+
static inline
407+
void geqrf(const int m, const int n, float* A, const int lda, float* tau, float* work, const int lwork, int& info)
408+
{
409+
sgeqrf_(&m, &n, A, &lda, tau, work, &lwork, &info);
410+
}
411+
static inline
412+
void geqrf(const int m, const int n, double* A, const int lda, double* tau, double* work, const int lwork, int& info)
413+
{
414+
dgeqrf_(&m, &n, A, &lda, tau, work, &lwork, &info);
415+
}
416+
static inline
417+
void geqrf(const int m, const int n, std::complex<float>* A, const int lda, std::complex<float>* tau, std::complex<float>* work, const int lwork, int& info)
418+
{
419+
cgeqrf_(&m, &n, A, &lda, tau, work, &lwork, &info);
420+
}
421+
static inline
422+
void geqrf(const int m, const int n, std::complex<double>* A, const int lda, std::complex<double>* tau, std::complex<double>* work, const int lwork, int& info)
423+
{
424+
zgeqrf_(&m, &n, A, &lda, tau, work, &lwork, &info);
425+
}
426+
427+
428+
401429
} // namespace lapackConnector
402430
} // namespace container
403431

0 commit comments

Comments
 (0)