Skip to content

Commit d9afb3e

Browse files
authored
Refactor: Unify standard/generalized eig driver of LAPACK (#6610)
* Update lapack_hegvd interface to support lda different than n
* Replace hegvd_op with lapack_hegvd in diago_dav_subspace
* Remove itype parameter for heevx that is required only in gv
* Update heevx to use correct arg list
* Add lapack_heevx
* Fix lapack_heevx and add template instantiation
* Replace hsolver::heevx_op with ct::kernels::lapack_heevx in diago_david
* Add lapack_hegvx
* Fix hip hegvd
* Fix hegvd to prevent B from being overwritten by Cholesky
* Switch on ATen/kernels/test
* Change container test TARGET name to be auto-run by CI
* Add test for heevx
* Add test for hegvx
* Remove test output code
* Clean the code and add docs
* Fix mismatched parentheses
1 parent 507c8ff commit d9afb3e

File tree

17 files changed

+1334
-143
lines changed

17 files changed

+1334
-143
lines changed

source/source_base/module_container/ATen/kernels/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,4 @@ if(BUILD_TESTING)
1616
if(ENABLE_MPI)
1717
add_subdirectory(test)
1818
endif()
19-
endif()
19+
endif()

source/source_base/module_container/ATen/kernels/cuda/lapack.cu

Lines changed: 97 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
#include <cuda_runtime.h>
77
#include <thrust/complex.h>
88

9+
#include <cassert>
10+
11+
912
namespace container {
1013
namespace kernels {
1114

@@ -101,22 +104,100 @@ struct lapack_heevd<T, DEVICE_GPU> {
101104
}
102105
};
103106

107+
/**
 * @brief GPU specialization of the index-range standard eigen-solver driver.
 *
 * Copies the input matrix into d_eigen_vec and then hands that copy to
 * cuSolverConnector::heevdx with jobz 'V', uplo 'L', range 'I' and the index
 * window [1, neig]. The input d_Mat is only read here (by the device-to-device
 * copy).
 *
 * @param n            Matrix size forwarded to heevdx; must satisfy n <= lda.
 * @param lda          Leading dimension forwarded to heevdx.
 * @param d_Mat        Device pointer to the input matrix; n * lda elements are copied.
 * @param neig         Upper index of the requested range (lower index is 1).
 * @param d_eigen_val  Output buffer forwarded to heevdx for the eigenvalues.
 * @param d_eigen_vec  Device buffer of at least n * lda elements; receives the
 *                     copy of d_Mat and is then passed to heevdx as its
 *                     working matrix.
 */
template <typename T>
struct lapack_heevx<T, DEVICE_GPU> {
    using Real = typename GetTypeReal<T>::type;
    void operator()(
        const int n,
        const int lda,
        const T *d_Mat,
        const int neig,
        Real *d_eigen_val,
        T *d_eigen_vec)
    {
        assert(n <= lda);

        // Solver options, named for readability.
        const char jobz = 'V';   // compute eigenvectors
        const char uplo = 'L';   // use the lower triangle
        const char range = 'I';  // select eigenpairs by index
        const int il = 1;        // first requested index
        const int iu = neig;     // last requested index
        const Real vl = Real(0); // value bounds are unused with range 'I'
        const Real vu = Real(0);

        // Work on a copy: the solver writes its results into d_eigen_vec,
        // so d_Mat itself is not modified by this call.
        const auto mat_bytes = sizeof(T) * n * lda;
        cudaErrcheck(cudaMemcpy(d_eigen_vec, d_Mat, mat_bytes, cudaMemcpyDeviceToDevice));

        // Filled in by heevdx (presumably the number of eigenvalues found);
        // it is not inspected afterwards.
        int found = 0;

        cuSolverConnector::heevdx(
            cusolver_handle,
            n,
            lda,
            d_eigen_vec,
            jobz,
            uplo,
            range,
            il, iu,
            vl, vu,
            d_eigen_val,
            &found);
    }
};
104141
/**
 * @brief GPU specialization of the generalized eigen-solver driver (all eigenpairs).
 *
 * Calls cuSolverConnector::hegvd with itype 1 (A*x = lambda*B*x in the
 * LAPACK/cuSOLVER convention), jobz 'V' and uplo 'L'. Neither input matrix is
 * handed to the solver directly: Mat_A is first copied into eigen_vec and
 * Mat_B into a temporary device buffer, and the solver works on those copies.
 *
 * @param dim        Matrix size forwarded to hegvd; must satisfy dim <= lda.
 * @param lda        Leading dimension used for both matrices.
 * @param Mat_A      Device pointer to matrix A; dim * lda elements are copied.
 * @param Mat_B      Device pointer to matrix B; dim * lda elements are copied.
 * @param eigen_val  Output buffer forwarded to hegvd for the eigenvalues.
 * @param eigen_vec  Device buffer of at least dim * lda elements; receives the
 *                   copy of Mat_A and is passed to hegvd in place of A.
 */
template <typename T>
struct lapack_hegvd<T, DEVICE_GPU> {
    using Real = typename GetTypeReal<T>::type;
    void operator()(
        const int dim,
        const int lda,
        T* Mat_A,
        T* Mat_B,
        Real* eigen_val,
        T *eigen_vec)
    {
        // Same precondition that lapack_heevx checks: the copies below move
        // dim * lda elements and the solver is told the leading dimension is
        // lda, so lda must not be smaller than dim.
        assert(dim <= lda);

        const int itype = 1;
        const char jobz = 'V';
        const char uplo = 'L';

        // The solver runs on eigen_vec, so Mat_A is left untouched.
        cudaErrcheck(cudaMemcpy(eigen_vec, Mat_A, sizeof(T) * dim * lda, cudaMemcpyDeviceToDevice));

        // prevent B from being overwritten by Cholesky
        T *d_B_backup = nullptr;
        cudaErrcheck(cudaMalloc(&d_B_backup, sizeof(T) * dim * lda));
        cudaErrcheck(cudaMemcpy(d_B_backup, Mat_B, sizeof(T) * dim * lda, cudaMemcpyDeviceToDevice));

        cuSolverConnector::hegvd(cusolver_handle, itype, jobz, uplo, dim,
                                 eigen_vec, lda,
                                 d_B_backup, lda,
                                 eigen_val);

        // The scratch copy of B is only needed for the duration of the solve.
        cudaErrcheck(cudaFree(d_B_backup));
    }
};
169+
170+
/**
 * @brief GPU specialization of the index-range generalized eigen-solver driver.
 *
 * Calls cuSolverConnector::hegvdx with itype 1 (A*x = lambda*B*x in the
 * LAPACK/cuSOLVER convention), jobz 'V', range 'I', uplo 'U' and the index
 * window [1, m].
 *
 * @param n          Matrix size forwarded to hegvdx; must satisfy n <= lda.
 * @param lda        Leading dimension forwarded to hegvdx.
 * @param A          Matrix A, forwarded to hegvdx.
 * @param B          Matrix B, forwarded to hegvdx.
 * @param m          Upper index of the requested range (lower index is 1).
 * @param eigen_val  Output buffer forwarded to hegvdx for the eigenvalues.
 * @param eigen_vec  Output buffer forwarded to hegvdx for the eigenvectors.
 */
template <typename T>
struct lapack_hegvx<T, DEVICE_GPU> {
    using Real = typename GetTypeReal<T>::type;
    void operator()(
        const int n,
        const int lda,
        T *A,
        T *B,
        const int m,
        Real *eigen_val,
        T *eigen_vec)
    {
        // Same precondition that lapack_heevx checks: a leading dimension
        // smaller than the matrix size is never valid.
        assert(n <= lda);

        const int itype = 1;
        const char jobz = 'V';
        const char range = 'I';
        const char uplo = 'U';

        // Filled in by hegvdx (presumably the number of eigenvalues found);
        // it is not inspected afterwards.
        int meig = 0;

        // this hegvdx will protect the input A, B from being overwritten
        // and write the eigenvectors into eigen_vec.
        cuSolverConnector::hegvdx(cusolver_handle,
                                  itype, jobz, range, uplo,
                                  n, lda, A, B,
                                  Real(0), Real(0), // vl, vu: unused with range 'I'
                                  1, m, &meig,      // il, iu, number found
                                  eigen_val, eigen_vec);
    }
};
119198

199+
200+
120201
template <typename T>
121202
struct lapack_getrf<T, DEVICE_GPU> {
122203
void operator()(
@@ -180,11 +261,21 @@ template struct lapack_heevd<double, DEVICE_GPU>;
180261
template struct lapack_heevd<std::complex<float>, DEVICE_GPU>;
181262
template struct lapack_heevd<std::complex<double>, DEVICE_GPU>;
182263

264+
template struct lapack_heevx<float, DEVICE_GPU>;
265+
template struct lapack_heevx<double, DEVICE_GPU>;
266+
template struct lapack_heevx<std::complex<float>, DEVICE_GPU>;
267+
template struct lapack_heevx<std::complex<double>, DEVICE_GPU>;
268+
183269
template struct lapack_hegvd<float, DEVICE_GPU>;
184270
template struct lapack_hegvd<double, DEVICE_GPU>;
185271
template struct lapack_hegvd<std::complex<float>, DEVICE_GPU>;
186272
template struct lapack_hegvd<std::complex<double>, DEVICE_GPU>;
187273

274+
template struct lapack_hegvx<float, DEVICE_GPU>;
275+
template struct lapack_hegvx<double, DEVICE_GPU>;
276+
template struct lapack_hegvx<std::complex<float>, DEVICE_GPU>;
277+
template struct lapack_hegvx<std::complex<double>, DEVICE_GPU>;
278+
188279
template struct lapack_getrf<float, DEVICE_GPU>;
189280
template struct lapack_getrf<double, DEVICE_GPU>;
190281
template struct lapack_getrf<std::complex<float>, DEVICE_GPU>;

0 commit comments

Comments
 (0)