Skip to content

Commit e2fe179

Browse files
committed
geqrf_inplace with tests
1 parent ec2582c commit e2fe179

File tree

8 files changed

+889
-90
lines changed

8 files changed

+889
-90
lines changed

source/source_base/module_container/ATen/kernels/cuda/lapack.cu

Lines changed: 75 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -122,93 +122,89 @@ struct lapack_getri<T, DEVICE_GPU> {
122122

123123

124124
template <typename T>
125-
struct lapack_getrf_inplace<T, DEVICE_GPU> {
126-
void operator(){
125+
struct lapack_geqrf_inplace<T, DEVICE_GPU> {
126+
void operator()(
127127
const int m,
128128
const int n,
129-
T *A,
129+
T *d_A,
130130
const int lda)
131131
{
132132
const int k = std::min(m, n);
133133

134-
// 1. Allocate tau on device
134+
// Allocate tau on device
135135
T *d_tau;
136136
cudaErrcheck(cudaMalloc(&d_tau, sizeof(T) * k));
137137

138-
// 2. Query for workspace size
139-
int lwork = 0;
140-
int *d_info;
141-
cudaErrcheck(cudaMalloc(&d_info, sizeof(int)));
142-
143-
// geqrf: workspace query
144-
cuSolverConnector::geqrf(cusolverH, m, n, d_A, lda, d_tau, nullptr, -1, d_info);
145-
// Note: cuSOLVER uses nullptr for query, result returned via lwork
146-
// But we need to call it with real pointer to get lwork
147-
T work_query;
148-
cuSolverConnector::geqrf(cusolverH, m, n, d_A, lda, d_tau, &work_query, -1, d_info);
149-
150-
// In practice, we use helper function to get lwork
151-
// Or use magma for better interface
152-
// Let's assume we have a way to get lwork
153-
// For now, do a dummy call to get it
154-
size_t workspaceInBytes = 0;
155-
cusolverErrcheck(cusolverDnXgeqrf_bufferSize(
156-
cusolverH, m, n,
157-
getCudaDataType<T>::type, d_A, lda,
158-
getCudaDataType<T>::type, // for tau
159-
CUDA_R_32F, // numerical precision
160-
CUSOLVER_WORKSPACE_QUERY_USE_MAX, &workspaceInBytes));
161-
162-
lwork = static_cast<int>(workspaceInBytes / sizeof(T));
163-
164-
// Allocate workspace
165-
T *d_work;
166-
cudaErrcheck(cudaMalloc(&d_work, sizeof(T) * lwork));
167-
168-
// 3. Perform geqrf
169-
cusolverErrcheck(cusolverDnXgeqrf(
170-
cusolverH, m, n,
171-
getCudaDataType<T>::type, d_A, lda,
172-
d_tau,
173-
getCudaDataType<T>::type,
174-
d_work, lwork * sizeof(T),
175-
d_info));
176-
177-
int info;
178-
cudaErrcheck(cudaMemcpy(&info, d_info, sizeof(int), cudaMemcpyDeviceToHost));
179-
if (info != 0) {
180-
throw std::runtime_error("cuSOLVER geqrf failed with info = " + std::to_string(info));
181-
}
138+
cuSolverConnector::geqrf(cusolver_handle, m, n, d_A, lda, d_tau);
182139

183-
// 4. Generate Q using orgqr
184-
// Query workspace for orgqr
185-
cusolverErrcheck(cusolverDnXorgqr_bufferSize(
186-
cusolverH, m, n, k,
187-
getCudaDataType<T>::type, d_A, lda,
188-
getCudaDataType<T>::type, d_tau,
189-
CUDA_R_32F,
190-
CUSOLVER_WORKSPACE_QUERY_USE_MAX, &workspaceInBytes));
191-
192-
lwork = static_cast<int>(workspaceInBytes / sizeof(T));
193-
cudaErrcheck(cudaRealloc(&d_work, sizeof(T) * lwork)); // or realloc
194-
195-
// orgqr: generate Q
196-
cusolverErrcheck(cusolverDnXorgqr(
197-
cusolverH, m, n, k,
198-
getCudaDataType<T>::type, d_A, lda,
199-
getCudaDataType<T>::type, d_tau,
200-
d_work, lwork * sizeof(T),
201-
d_info));
202-
203-
cudaErrcheck(cudaMemcpy(&info, d_info, sizeof(int), cudaMemcpyDeviceToHost));
204-
if (info != 0) {
205-
throw std::runtime_error("cuSOLVER orgqr failed with info = " + std::to_string(info));
206-
}
140+
cuSolverConnector::orgqr(cusolver_handle, m, n, k, d_A, lda, d_tau);
207141

208-
// Clean up
209142
cudaErrcheck(cudaFree(d_tau));
210-
cudaErrcheck(cudaFree(d_work));
211-
cudaErrcheck(cudaFree(d_info));
143+
144+
// // geqrf: workspace query
145+
146+
// // In practice, we use helper function to get lwork
147+
// // Or use magma for better interface
148+
// // Let's assume we have a way to get lwork
149+
// // For now, do a dummy call to get it
150+
// size_t workspaceInBytes = 0;
151+
// cusolverErrcheck(cusolverDnXgeqrf_bufferSize(
152+
// cusolverH, m, n,
153+
// getCudaDataType<T>::type, d_A, lda,
154+
// getCudaDataType<T>::type, // for tau
155+
// CUDA_R_32F, // numerical precision
156+
// CUSOLVER_WORKSPACE_QUERY_USE_MAX, &workspaceInBytes));
157+
158+
// lwork = static_cast<int>(workspaceInBytes / sizeof(T));
159+
160+
// // Allocate workspace
161+
// T *d_work;
162+
// cudaErrcheck(cudaMalloc(&d_work, sizeof(T) * lwork));
163+
164+
// // 3. Perform geqrf
165+
// cusolverErrcheck(cusolverDnXgeqrf(
166+
// cusolverH, m, n,
167+
// getCudaDataType<T>::type, d_A, lda,
168+
// d_tau,
169+
// getCudaDataType<T>::type,
170+
// d_work, lwork * sizeof(T),
171+
// d_info));
172+
173+
// int info;
174+
// cudaErrcheck(cudaMemcpy(&info, d_info, sizeof(int), cudaMemcpyDeviceToHost));
175+
// if (info != 0) {
176+
// throw std::runtime_error("cuSOLVER geqrf failed with info = " + std::to_string(info));
177+
// }
178+
179+
// // 4. Generate Q using orgqr
180+
// // Query workspace for orgqr
181+
// cusolverErrcheck(cusolverDnXorgqr_bufferSize(
182+
// cusolverH, m, n, k,
183+
// getCudaDataType<T>::type, d_A, lda,
184+
// getCudaDataType<T>::type, d_tau,
185+
// CUDA_R_32F,
186+
// CUSOLVER_WORKSPACE_QUERY_USE_MAX, &workspaceInBytes));
187+
188+
// lwork = static_cast<int>(workspaceInBytes / sizeof(T));
189+
// cudaErrcheck(cudaRealloc(&d_work, sizeof(T) * lwork)); // or realloc
190+
191+
// // orgqr: generate Q
192+
// cusolverErrcheck(cusolverDnXorgqr(
193+
// cusolverH, m, n, k,
194+
// getCudaDataType<T>::type, d_A, lda,
195+
// getCudaDataType<T>::type, d_tau,
196+
// d_work, lwork * sizeof(T),
197+
// d_info));
198+
199+
// cudaErrcheck(cudaMemcpy(&info, d_info, sizeof(int), cudaMemcpyDeviceToHost));
200+
// if (info != 0) {
201+
// throw std::runtime_error("cuSOLVER orgqr failed with info = " + std::to_string(info));
202+
// }
203+
204+
// // Clean up
205+
// cudaErrcheck(cudaFree(d_tau));
206+
// cudaErrcheck(cudaFree(d_work));
207+
// cudaErrcheck(cudaFree(d_info));
212208
}
213209
};
214210

@@ -391,7 +387,10 @@ template struct lapack_getri<double, DEVICE_GPU>;
391387
template struct lapack_getri<std::complex<float>, DEVICE_GPU>;
392388
template struct lapack_getri<std::complex<double>, DEVICE_GPU>;
393389

394-
390+
template struct lapack_geqrf_inplace<float, DEVICE_GPU>;
391+
template struct lapack_geqrf_inplace<double, DEVICE_GPU>;
392+
template struct lapack_geqrf_inplace<std::complex<float>, DEVICE_GPU>;
393+
template struct lapack_geqrf_inplace<std::complex<double>, DEVICE_GPU>;
395394

396395
} // namespace kernels
397396
} // namespace container

source/source_base/module_container/ATen/kernels/lapack.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,10 @@ template struct lapack_getrs<double, DEVICE_CPU>;
493493
template struct lapack_getrs<std::complex<float>, DEVICE_CPU>;
494494
template struct lapack_getrs<std::complex<double>, DEVICE_CPU>;
495495

496+
template struct lapack_geqrf_inplace<float, DEVICE_CPU>;
497+
template struct lapack_geqrf_inplace<double, DEVICE_CPU>;
498+
template struct lapack_geqrf_inplace<std::complex<float>, DEVICE_CPU>;
499+
template struct lapack_geqrf_inplace<std::complex<double>, DEVICE_CPU>;
496500

497501
template struct lapack_heevd<float, DEVICE_CPU>;
498502
template struct lapack_heevd<double, DEVICE_CPU>;

source/source_base/module_container/ATen/kernels/lapack.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,20 @@ struct lapack_getri {
6767
// that will change input Mat A to orthogonal/unitary matrix Q
6868
template <typename T, typename Device>
6969
struct lapack_geqrf_inplace {
70+
/**
71+
* @brief Perform in-place QR factorization of a matrix using LAPACK's geqrf function.
72+
*
73+
* This function computes the QR factorization of an m-by-n matrix A as A = Q * R,
74+
* where Q is an orthogonal/unitary matrix and R is an upper triangular matrix.
75+
* The factorization is performed in-place, meaning the input matrix A will be modified.
76+
*
77+
* On exit: A is overwritten with the orthogonal/unitary factor Q of the factorization.
78+
*
79+
* @param m The number of rows in the matrix A. m >= 0
80+
* @param n The number of columns in the matrix A. n >= 0
81+
* @param A Pointer to the matrix A to be factorized. On exit, contains the orthogonal/unitary factor Q
82+
* @param lda The leading dimension of the matrix A. lda >= max(1, m)
83+
*/
7084
void operator()(
7185
const int m,
7286
const int n,

source/source_base/module_container/ATen/kernels/test/lapack_test.cpp

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,74 @@ TYPED_TEST(LapackTest, Potrf) {
9292
EXPECT_EQ(A, C);
9393
}
9494

95+
// lapack_geqrf_inplace,
96+
// check that QtQ = I
97+
TYPED_TEST(LapackTest, GeqrfInPlace) {
98+
using Type = typename std::tuple_element<0, decltype(TypeParam())>::type;
99+
using Device = typename std::tuple_element<1, decltype(TypeParam())>::type;
100+
101+
lapack_geqrf_inplace<Type, Device> geqrfCalculator;
102+
103+
const int m = 4;
104+
const int n = 3; // m >= n,Q is m x n column-orthogonal matrix
105+
const int lda = m;
106+
107+
Tensor A_input = std::move(Tensor({
108+
static_cast<Type>(1.0), static_cast<Type>(2.0), static_cast<Type>(3.0), static_cast<Type>(4.0),
109+
static_cast<Type>(5.0), static_cast<Type>(6.0), static_cast<Type>(7.0), static_cast<Type>(8.0),
110+
static_cast<Type>(9.0), static_cast<Type>(10.0), static_cast<Type>(11.0), static_cast<Type>(12.0)
111+
}).to_device<Device>());
112+
113+
Tensor A = A_input; // will be overwritten as Q
114+
115+
// do geqrf -> get orthogonal Q
116+
geqrfCalculator(m, n, A.data<Type>(), lda);
117+
118+
// check on CPU
119+
Tensor Q = A.to_device<DEVICE_CPU>();
120+
const Type* Q_data = Q.data<Type>();
121+
122+
// compute QtQ = Q^T * Q (n x n)
123+
Tensor QtQ = Q; // std::move(Tensor(std::vector<Type>(n * n, static_cast<Type>(0.0))).to_device<DEVICE_CPU>());
124+
const Type alpha = static_cast<Type>(1.0);
125+
const Type beta = static_cast<Type>(0.0);
126+
127+
blas_gemm<Type, DEVICE_CPU> gemm;
128+
gemm('C', 'N', // Q^T * Q
129+
n, n, m, // n x n
130+
&alpha,
131+
Q_data, lda, // Q^T
132+
Q_data, lda, // Q
133+
&beta,
134+
QtQ.data<Type>(), n);
135+
136+
// Test code: print A
137+
std::cout << "A = " << std::endl;
138+
for (int i = 0; i < m; ++i) {
139+
for (int j = 0; j < n; ++j) {
140+
std::cout << A_input.to_device<DEVICE_CPU>().data<Type>()[i + j * m] << " ";
141+
}
142+
std::cout << std::endl;
143+
}
144+
// Test code: print QtQ
145+
std::cout << "QtQ = " << std::endl;
146+
for (int i = 0; i < n; ++i) {
147+
for (int j = 0; j < n; ++j) {
148+
std::cout << QtQ.data<Type>()[i + j * n] << " ";
149+
}
150+
std::cout << std::endl;
151+
}
152+
153+
// check QtQ
154+
for (int i = 0; i < n; ++i) {
155+
for (int j = 0; j < n; ++j) {
156+
Type expected = (i == j) ? static_cast<Type>(1.0) : static_cast<Type>(0.0);
157+
EXPECT_NEAR(std::abs(QtQ.data<Type>()[i + j * n]), std::abs(expected), 1e-5)
158+
<< "Q^T * Q not identity at (" << i << "," << j << ")";
159+
}
160+
}
161+
}
162+
95163
// Test for lapack_heevd and lapack_heevx:
96164
// Solve a standard eigenvalue problem
97165
// and check that A*V = V*E

source/source_base/module_container/base/third_party/blas.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ void caxpy_(const int *N, const std::complex<float> *alpha, const std::complex<f
2626
void zaxpy_(const int *N, const std::complex<double> *alpha, const std::complex<double> *x, const int *incx, std::complex<double> *y, const int *incy);
2727

2828
void scopy_(const int *n, const float *a, const int *incx, float *b, int const *incy);
29-
void dcopy_(const int *n, const double *a, const *incx, double *b, int const *incy);
29+
void dcopy_(const int *n, const double *a, const int *incx, double *b, int const *incy);
3030
void ccopy_(const int *n, const std::complex<float> *a, const int *incx, std::complex<float> *b, int const *incy);
3131
void zcopy_(const int *n, const std::complex<double> *a, const int *incx, std::complex<double> *b, int const *incy);
3232

0 commit comments

Comments
 (0)