Skip to content

Commit 0d3f2d0

Browse files
committed
Add geqrf
1 parent dba1900 commit 0d3f2d0

File tree

5 files changed

+273
-0
lines changed

5 files changed

+273
-0
lines changed

source/source_base/module_container/ATen/kernels/cuda/lapack.cu

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,97 @@ struct lapack_getri<T, DEVICE_GPU> {
121121
};
122122

123123

124+
template <typename T>
125+
struct lapack_getrf_inplace<T, DEVICE_GPU> {
126+
void operator(){
127+
const int m,
128+
const int n,
129+
T *A,
130+
const int lda)
131+
{
132+
const int k = std::min(m, n);
133+
134+
// 1. Allocate tau on device
135+
T *d_tau;
136+
cudaErrcheck(cudaMalloc(&d_tau, sizeof(T) * k));
137+
138+
// 2. Query for workspace size
139+
int lwork = 0;
140+
int *d_info;
141+
cudaErrcheck(cudaMalloc(&d_info, sizeof(int)));
142+
143+
// geqrf: workspace query
144+
cuSolverConnector::geqrf(cusolverH, m, n, d_A, lda, d_tau, nullptr, -1, d_info);
145+
// Note: cuSOLVER uses nullptr for query, result returned via lwork
146+
// But we need to call it with real pointer to get lwork
147+
T work_query;
148+
cuSolverConnector::geqrf(cusolverH, m, n, d_A, lda, d_tau, &work_query, -1, d_info);
149+
150+
// In practice, we use helper function to get lwork
151+
// Or use magma for better interface
152+
// Let's assume we have a way to get lwork
153+
// For now, do a dummy call to get it
154+
size_t workspaceInBytes = 0;
155+
cusolverErrcheck(cusolverDnXgeqrf_bufferSize(
156+
cusolverH, m, n,
157+
getCudaDataType<T>::type, d_A, lda,
158+
getCudaDataType<T>::type, // for tau
159+
CUDA_R_32F, // numerical precision
160+
CUSOLVER_WORKSPACE_QUERY_USE_MAX, &workspaceInBytes));
161+
162+
lwork = static_cast<int>(workspaceInBytes / sizeof(T));
163+
164+
// Allocate workspace
165+
T *d_work;
166+
cudaErrcheck(cudaMalloc(&d_work, sizeof(T) * lwork));
167+
168+
// 3. Perform geqrf
169+
cusolverErrcheck(cusolverDnXgeqrf(
170+
cusolverH, m, n,
171+
getCudaDataType<T>::type, d_A, lda,
172+
d_tau,
173+
getCudaDataType<T>::type,
174+
d_work, lwork * sizeof(T),
175+
d_info));
176+
177+
int info;
178+
cudaErrcheck(cudaMemcpy(&info, d_info, sizeof(int), cudaMemcpyDeviceToHost));
179+
if (info != 0) {
180+
throw std::runtime_error("cuSOLVER geqrf failed with info = " + std::to_string(info));
181+
}
182+
183+
// 4. Generate Q using orgqr
184+
// Query workspace for orgqr
185+
cusolverErrcheck(cusolverDnXorgqr_bufferSize(
186+
cusolverH, m, n, k,
187+
getCudaDataType<T>::type, d_A, lda,
188+
getCudaDataType<T>::type, d_tau,
189+
CUDA_R_32F,
190+
CUSOLVER_WORKSPACE_QUERY_USE_MAX, &workspaceInBytes));
191+
192+
lwork = static_cast<int>(workspaceInBytes / sizeof(T));
193+
cudaErrcheck(cudaRealloc(&d_work, sizeof(T) * lwork)); // or realloc
194+
195+
// orgqr: generate Q
196+
cusolverErrcheck(cusolverDnXorgqr(
197+
cusolverH, m, n, k,
198+
getCudaDataType<T>::type, d_A, lda,
199+
getCudaDataType<T>::type, d_tau,
200+
d_work, lwork * sizeof(T),
201+
d_info));
202+
203+
cudaErrcheck(cudaMemcpy(&info, d_info, sizeof(int), cudaMemcpyDeviceToHost));
204+
if (info != 0) {
205+
throw std::runtime_error("cuSOLVER orgqr failed with info = " + std::to_string(info));
206+
}
207+
208+
// Clean up
209+
cudaErrcheck(cudaFree(d_tau));
210+
cudaErrcheck(cudaFree(d_work));
211+
cudaErrcheck(cudaFree(d_info));
212+
}
213+
};
214+
124215
// --- 2. Linear System Solvers ---
125216
template <typename T>
126217
struct lapack_getrs<T, DEVICE_GPU> {

source/source_base/module_container/ATen/kernels/lapack.cpp

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,64 @@ struct lapack_getri<T, DEVICE_CPU> {
110110
}
111111
};
112112

113+
template <typename T>
114+
struct lapack_geqrf_inplace<T, DEVICE_CPU> {
115+
void operator()(
116+
const int m,
117+
const int n,
118+
T *A,
119+
const int lda)
120+
{
121+
// Tensor or vector?
122+
// 1. tau for storing the Householder reflectors
123+
// tau should be dimension min(m, n)
124+
int k = std::min(m, n);
125+
Tensor tau(DataTypeToEnum<T>::value, DeviceType::CpuDevice, {k});
126+
tau.zero();
127+
128+
int info = 0;
129+
130+
// 2. query for workspace size
131+
int lwork = -1;
132+
T work_query;
133+
lapackConnector::geqrf(m, n, A, lda, tau.data<T>(), &work_query, lwork, info);
134+
if (info != 0) {
135+
throw std::runtime_error("geqrf workspace query failed with info = " + std::to_string(info));
136+
}
137+
// allocate workspace
138+
lwork = static_cast<int>(get_real(work_query));
139+
Tensor work(DataTypeToEnum<T>::value, DeviceType::CpuDevice, {lwork});
140+
work.zero();
141+
142+
// 3. perform QR decomposition
143+
// and A is overwritten with upper R.
144+
// Lower A + tau => Q
145+
lapackConnector::geqrf(m, n, A, lda, tau.data<T>(), work.data<T>(), lwork, info);
146+
if (info != 0) {
147+
throw std::runtime_error("geqrf failed with info = " + std::to_string(info));
148+
}
149+
150+
// 4. use orgqr to compute Q
151+
// workspace query
152+
lwork = -1;
153+
lapackConnector::orgqr(m, n, k, A, lda, tau.data<T>(), &work_query, lwork, info);
154+
if (info != 0) {
155+
throw std::runtime_error("orgqr workspace query failed with info = " + std::to_string(info));
156+
}
157+
// allocate workspace
158+
lwork = static_cast<int>(get_real(work_query));
159+
work.resize({lwork});
160+
161+
// compute Q
162+
lapackConnector::orgqr(m, n, k, A, lda, tau.data<T>(), work.data<T>(), lwork, info);
163+
if (info != 0) {
164+
throw std::runtime_error("orgqr failed with info = " + std::to_string(info));
165+
}
166+
167+
// now, A should be overwritten with Q, columns orthogonal
168+
169+
}
170+
};
113171

114172
// --- 2. Linear System Solvers ---
115173
template <typename T>

source/source_base/module_container/ATen/kernels/lapack.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,41 @@ struct lapack_getri {
6363
const int& lwork);
6464
};
6565

66+
/**
 * In-place QR factorization.
 *
 * The input matrix A is modified: on return it holds the
 * orthogonal/unitary matrix Q of the factorization.
 *
 * @param m   The number of rows in the matrix.
 * @param n   The number of columns in the matrix.
 * @param A   The matrix to be factorized; overwritten with Q.
 * @param lda The leading dimension of the matrix.
 */
template <typename T, typename Device>
struct lapack_geqrf_inplace {
    void operator()(const int m, const int n, T* A, const int lda);
};
76+
77+
// This is QR factorization
78+
// where [in]Mat will be kept and the results are stored in separate matrix Q
79+
// template <typename T, typename Device>
80+
// struct lapack_geqrf{
81+
// /**
82+
// * Perform QR factorization of a matrix using LAPACK's geqrf function.
83+
// *
84+
// * @param m The number of rows in the matrix.
85+
// * @param n The number of columns in the matrix.
86+
// * @param Mat The matrix to be factorized.
87+
// * On exit, the upper triangle contains the upper triangular matrix R,
88+
// * and the elements below the diagonal, with the array TAU, represent
89+
// * the unitary matrix Q as a product of min(m,n) elementary reflectors.
90+
// * @param lda The leading dimension of the matrix.
91+
// * @param tau Array of size min(m,n) containing the Householder reflectors.
92+
// */
93+
// void operator()(
94+
// const int m,
95+
// const int n,
96+
// T *Mat,
97+
// const int lda,
98+
// T *tau);
99+
// };
100+
66101

67102
// --- 2. Linear System Solvers ---
68103
template <typename T, typename Device>

source/source_base/module_container/base/third_party/cusolver.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1136,6 +1136,34 @@ void getrs(cusolverDnHandle_t& cusolver_handle, const char& trans, const int& n,
11361136
cudaErrcheck(cudaFree(d_info));
11371137
}
11381138

1139+
// QR decomposition
1140+
// geqrf, orgqr
1141+
// Note:
1142+
// there are two cusolver geqrf
1143+
// one is cusolverDn<t>geqrf
1144+
// one is cusolverDnXgeqrf
1145+
// which one is better?
1146+
static inline
1147+
void geqrf(cusolverDnHandle_t& cusolver_handle, const int m, const int n, std::complex<float>* A, const int lda, std::complex<float>* tau)
1148+
{
1149+
// first allocate memory for workspace
1150+
int lwork = 0;
1151+
cusolverErrcheck(cusolverDnCgeqrf_bufferSize(cusolver_handle, m, n, reinterpret_cast<cuComplex*>(A), lda, &lwork));
1152+
1153+
std::complex<float>* d_work = nullptr;
1154+
cudaErrcheck(cudaMalloc((void**)&d_work, lwork * sizeof(std::complex<float>)));
1155+
1156+
// compute QR decomposition
1157+
cusolverErrcheck(cusolverDnCgeqrf(cusolver_handle, m, n, reinterpret_cast<cuComplex*>(A), lda, reinterpret_cast<cuComplex*>(tau), d_work, lwork, d_info));
1158+
1159+
cudaErrcheck(cudaMemcpy(&h_info, d_info, sizeof(int), cudaMemcpyDeviceToHost));
1160+
if (h_info != 0) {
1161+
throw std::runtime_error("geqrf: failed to compute QR decomposition");
1162+
}
1163+
1164+
cudaErrcheck(cudaFree(d_work));
1165+
}
1166+
11391167
} // namespace cuSolverConnector
11401168
} // namespace container
11411169

source/source_base/module_container/base/third_party/lapack.h

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ void dgetri_(const int* n, double* A, const int* lda, const int* ipiv, double* w
181181
void cgetri_(const int* n, std::complex<float>* A, const int* lda, const int* ipiv, std::complex<float>* work, const int* lwork, int* info);
182182
void zgetri_(const int* n, std::complex<double>* A, const int* lda, const int* ipiv, std::complex<double>* work, const int* lwork, int* info);
183183

184+
// Solve linear system using LU factorization
void sgetrs_(const char* trans, const int* n, const int* nrhs,
             const float* A, const int* lda, const int* ipiv,
             float* B, const int* ldb, int* info);
void dgetrs_(const char* trans, const int* n, const int* nrhs,
             const double* A, const int* lda, const int* ipiv,
             double* B, const int* ldb, int* info);
void cgetrs_(const char* trans, const int* n, const int* nrhs,
             const std::complex<float>* A, const int* lda, const int* ipiv,
             std::complex<float>* B, const int* ldb, int* info);
void zgetrs_(const char* trans, const int* n, const int* nrhs,
             const std::complex<double>* A, const int* lda, const int* ipiv,
             std::complex<double>* B, const int* ldb, int* info);

// QR factorization: A = Q * R, with Q stored as Householder reflectors (A + tau)
void sgeqrf_(const int* m, const int* n, float* A, const int* lda, float* tau, float* work, const int* lwork, int* info);
void dgeqrf_(const int* m, const int* n, double* A, const int* lda, double* tau, double* work, const int* lwork, int* info);
void cgeqrf_(const int* m, const int* n, std::complex<float>* A, const int* lda, std::complex<float>* tau, std::complex<float>* work, const int* lwork, int* info);
void zgeqrf_(const int* m, const int* n, std::complex<double>* A, const int* lda, std::complex<double>* tau, std::complex<double>* work, const int* lwork, int* info);

// Generate the explicit Q from the reflectors produced by ?geqrf
// (?orgqr for real types, ?ungqr for complex types)
void sorgqr_(const int* m, const int* n, const int* k, float* A, const int* lda, const float* tau, float* work, const int* lwork, int* info);
void dorgqr_(const int* m, const int* n, const int* k, double* A, const int* lda, const double* tau, double* work, const int* lwork, int* info);
void cungqr_(const int* m, const int* n, const int* k, std::complex<float>* A, const int* lda, const std::complex<float>* tau, std::complex<float>* work, const int* lwork, int* info);
void zungqr_(const int* m, const int* n, const int* k, std::complex<double>* A, const int* lda, const std::complex<double>* tau, std::complex<double>* work, const int* lwork, int* info);
// NOTE: `zunqrf_` is not a LAPACK symbol (it is a misspelling of zungqr_).
// The prototype is retained only so that code still using the old spelling
// keeps compiling; such code cannot link. Use zungqr_ instead.
void zunqrf_(const int* m, const int* n, const int* k, std::complex<double>* A, const int* lda, const std::complex<double>* tau, std::complex<double>* work, const int* lwork, int* info);
197215
}
198216

199217
// Class LapackConnector provide the connector to fortran lapack routine.
@@ -535,6 +553,49 @@ void getrs(const char& trans, const int n, const int nrhs, std::complex<double>*
535553
zgetrs_(&trans, &n, &nrhs, A, &lda, ipiv, B, &ldb, &info);
536554
}
537555

556+
// LAPACK routines for QR decomposition
557+
static inline
558+
void geqrf(const int m, const int n, float* A, const int lda, float* tau, float* work, const int lwork, int& info)
559+
{
560+
sgeqrf_(&m, &n, A, &lda, tau, work, &lwork, &info);
561+
}
562+
static inline
563+
void geqrf(const int m, const int n, double* A, const int lda, double* tau, double* work, const int lwork, int& info)
564+
{
565+
dgeqrf_(&m, &n, A, &lda, tau, work, &lwork, &info);
566+
}
567+
static inline
568+
void geqrf(const int m, const int n, std::complex<float>* A, const int lda, std::complex<float>* tau, std::complex<float>* work, const int lwork, int& info)
569+
{
570+
cgeqrf_(&m, &n, A, &lda, tau, work, &lwork, &info);
571+
}
572+
static inline
573+
void geqrf(const int m, const int n, std::complex<double>* A, const int lda, std::complex<double>* tau, std::complex<double>* work, const int lwork, int& info)
574+
{
575+
zgeqrf_(&m, &n, A, &lda, tau, work, &lwork, &info);
576+
}
577+
// these routines generate the orthogonal matrix Q from the QR decomposition
578+
static inline
579+
void orgqr(const int m, const int n, const int k, float* A, const int lda, const float* tau, float* work, const int lwork, int& info)
580+
{
581+
sorgqr_(&m, &n, &k, A, &lda, tau, work, &lwork, &info);
582+
}
583+
static inline
584+
void orgqr(const int m, const int n, const int k, double* A, const int lda, const double* tau, double* work, const int lwork, int& info)
585+
{
586+
dorgqr_(&m, &n, &k, A, &lda, tau, work, &lwork, &info);
587+
}
588+
static inline
589+
void orgqr(const int m, const int n, const int k, std::complex<float>* A, const int lda, const std::complex<float>* tau, std::complex<float>* work, const int lwork, int& info)
590+
{
591+
cungqr_(&m, &n, &k, A, &lda, tau, work, &lwork, &info);
592+
}
593+
static inline
594+
void orgqr(const int m, const int n, const int k, std::complex<double>* A, const int lda, const std::complex<double>* tau, std::complex<double>* work, const int lwork, int& info)
595+
{
596+
zunqrf_(&m, &n, &k, A, &lda, tau, work, &lwork, &info);
597+
}
598+
538599
} // namespace lapackConnector
539600
} // namespace container
540601

0 commit comments

Comments
 (0)