Skip to content

Commit 82b8acf

Browse files
committed
Add lapack_heevx
1 parent 679286e commit 82b8acf

File tree

5 files changed

+434
-4
lines changed

5 files changed

+434
-4
lines changed

source/source_base/module_container/ATen/kernels/cuda/lapack.cu

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,39 @@ struct lapack_heevd<T, DEVICE_GPU> {
101101
}
102102
};
103103

104+
template <typename T>
105+
struct lapack_heevx<T, DEVICE_GPU> {
106+
using Real = typename GetTypeReal<T>::type;
107+
void operator()(
108+
const int n,
109+
const int lda,
110+
T *d_Mat,
111+
const int neig,
112+
Real *d_eigen_val,
113+
T *d_eigen_vec)
114+
{
115+
// copy d_Mat to d_eigen_vec, and results will be overwritten into d_eigen_vec
116+
// by cuSolver
117+
cudaErrcheck(cudaMemcpy(d_eigen_vec, d_Mat, sizeof(T) * n * lda, cudaMemcpyDeviceToDevice));
118+
119+
int meig = 0;
120+
121+
cuSolverConnector::heevdx(
122+
cusolver_handle,
123+
n,
124+
lda,
125+
d_eigen_vec,
126+
'V', // jobz: compute vectors
127+
'L', // uplo: lower triangle
128+
'I', // range: by index
129+
1, neig, // il, iu
130+
Real(0), Real(0), // vl, vu (unused)
131+
d_eigen_val,
132+
&meig
133+
);
134+
135+
}
136+
};
104137
template <typename T>
105138
struct lapack_hegvd<T, DEVICE_GPU> {
106139
using Real = typename GetTypeReal<T>::type;

source/source_base/module_container/ATen/kernels/lapack.cpp

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,17 @@
44

55
// #include <cstring> // std::memcpy
66
#include <algorithm> // std::copy
7+
#include <stdexcept>
8+
#include <string>
79

810
namespace container {
911
namespace kernels {
1012

13+
inline double get_real(const std::complex<double> &x) { return x.real(); }
14+
inline float get_real(const std::complex<float> &x) { return x.real(); }
15+
inline double get_real(const double &x) { return x; }
16+
inline float get_real(const float &x) { return x; }
17+
1118
template <typename T>
1219
struct set_matrix<T, DEVICE_CPU> {
1320
void operator() (
@@ -95,6 +102,96 @@ struct lapack_heevd<T, DEVICE_CPU> {
95102
}
96103
};
97104

105+
template <typename T>
106+
struct lapack_heevx<T, DEVICE_CPU> {
107+
using Real = typename GetTypeReal<T>::type;
108+
void operator()(
109+
const int n,
110+
const int lda,
111+
T *Mat,
112+
const int neig,
113+
Real *eigen_val,
114+
T *eigen_vec)
115+
{
116+
Tensor aux(DataTypeToEnum<T>::value, DeviceType::CpuDevice, {n * lda});
117+
// Copy Mat to aux since heevx will destroy it
118+
// aux = Mat
119+
std::copy(Mat, Mat + n * lda, aux);
120+
121+
char jobz = 'V'; // Compute eigenvalues and eigenvectors
122+
char range = 'I'; // Find eigenvalues in index range [il, iu]
123+
char uplo = 'L'; // Use Lower triangle
124+
int info = 0;
125+
int found = 0; // Number of eigenvalues found
126+
// found should be iu - il + 1, i.e. found = neig
127+
const int il = 1;
128+
const int iu = neig;
129+
Real abstol = 0.0;
130+
131+
// Workspace query first
132+
int lwork = -1;
133+
T work_query;
134+
Real rwork_query;
135+
int iwork_query;
136+
int ifail_query;
137+
138+
// Dummy call to get optimal workspace size
139+
// when lwork = -1
140+
lapackConnector::heevx(
141+
jobz, range, uplo, n,
142+
aux, lda,
143+
0.0, 0.0, il, iu, // vl, vu not used when range='I'
144+
abstol,
145+
&found,
146+
eigen_val,
147+
eigen_vec, lda,
148+
&work_query, lwork,
149+
&rwork_query,
150+
&iwork_query,
151+
&ifail_query,
152+
&info);
153+
154+
if (info != 0) {
155+
throw std::runtime_error("heevx workspace query failed with info = " + std::to_string(info));
156+
}
157+
158+
lwork = static_cast<int>(get_real(work_query));
159+
160+
// Allocate buffers using Tensor (RAII)
161+
Tensor work(DataTypeToEnum<T>::value, DeviceType::CpuDevice, {lwork});
162+
work.zero();
163+
164+
Tensor rwork(DataTypeToEnum<Real>::value, DeviceType::CpuDevice, {7 * n});
165+
rwork.zero();
166+
167+
Tensor iwork(DataType::DT_INT, DeviceType::CpuDevice, {5 * n});
168+
iwork.zero();
169+
170+
Tensor ifail(DataType::DT_INT, DeviceType::CpuDevice, {n});
171+
ifail.zero();
172+
173+
// Actual call to heevx
174+
lapackConnector::heevx(
175+
jobz, range, uplo, n,
176+
aux, lda,
177+
0.0, 0.0, il, iu,
178+
abstol,
179+
&found,
180+
eigen_val,
181+
eigen_vec, lda,
182+
work.data<T>(), lwork,
183+
rwork.data<Real>(),
184+
iwork.data<int>(),
185+
ifail.data<int>(),
186+
&info);
187+
188+
if (info != 0) {
189+
throw std::runtime_error("heevx failed with info = " + std::to_string(info));
190+
}
191+
192+
}
193+
};
194+
98195
template <typename T>
99196
struct lapack_hegvd<T, DEVICE_CPU> {
100197
using Real = typename GetTypeReal<T>::type;

source/source_base/module_container/ATen/kernels/lapack.h

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#ifndef ATEN_KERNELS_LAPACK_H_
22
#define ATEN_KERNELS_LAPACK_H_
33

4+
#include "source_base/macros.h"
45
#include <ATen/core/tensor.h>
56
#include <ATen/core/tensor_types.h>
67

@@ -51,6 +52,40 @@ struct lapack_heevd {
5152
Real* eigen_val);
5253
};
5354

55+
template <typename T, typename Device>
56+
struct lapack_heevx {
57+
using Real = typename GetTypeReal<T>::type;
58+
/**
59+
* @brief Computes selected eigenvalues and, optionally, eigenvectors of a complex Hermitian matrix.
60+
*
61+
* This function solves the problem A*x = lambda*x, where A is a Hermitian matrix.
62+
* It computes a subset of eigenvalues and, optionally, the corresponding eigenvectors.
63+
*
64+
* @param jobz 'N': Compute eigenvalues only; 'V': Compute eigenvalues and eigenvectors.
65+
* @param range 'A': All eigenvalues; 'V': Eigenvalues in the half-open interval (vl, vu]; 'I': Eigenvalues with indices il through iu.
66+
* @param uplo 'U': Upper triangle of A is stored; 'L': Lower triangle is stored.
67+
* @param dim The order of the matrix A. dim >= 0.
68+
* @param Mat On entry, the Hermitian matrix A. On exit, it may be overwritten.
69+
* @param vl Lower bound of the interval to search for eigenvalues if range == 'V'.
70+
* @param vu Upper bound of the interval to search for eigenvalues if range == 'V'.
71+
* @param il Index of the smallest eigenvalue to be returned if range == 'I'.
72+
* @param iu Index of the largest eigenvalue to be returned if range == 'I'.
73+
* @param m Output: The total number of found eigenvalues.
74+
* @param eigen_val Array to store the computed eigenvalues in ascending order.
75+
* @param eigen_vec If not nullptr and jobz == 'V', array to store the computed eigenvectors.
76+
*
77+
* @note
78+
* See LAPACK ZHEEVX or CHEEVX documentation for more details.
79+
*
80+
*/
81+
void operator()(
82+
const int dim,
83+
const int lda,
84+
T *Mat,
85+
const int neig,
86+
Real *eigen_val,
87+
T *eigen_vec);
88+
};
5489

5590
template <typename T, typename Device>
5691
struct lapack_hegvd {
@@ -60,8 +95,8 @@ struct lapack_hegvd {
6095
*
6196
* This function solves the problem A*x = lambda*B*x, where A and B are Hermitian matrices, and B is also positive definite.
6297
*
63-
* @param dim The order of the matrices Mat_A and Mat_B. dim >= 0.
64-
* @param lda The leading dimension of the arrays Mat_A and Mat_B. lda >= max(1, dim).
98+
* @param n The order of the matrices Mat_A and Mat_B. n >= 0.
99+
* @param lda The leading dimension of the arrays Mat_A and Mat_B. lda >= max(1, n).
65100
* @param Mat_A On entry, the Hermitian matrix A. On exit, it may be overwritten.
66101
* @param Mat_B On entry, the Hermitian positive definite matrix B. On exit, it may be overwritten.
67102
* @param eigen_val Array to store the computed eigenvalues in ascending order.
@@ -72,7 +107,7 @@ struct lapack_hegvd {
72107
* This function assumes that A and B have the same leading dimensions, lda.
73108
*/
74109
void operator()(
75-
const int dim,
110+
const int n,
76111
const int lda,
77112
T *Mat_A,
78113
T *Mat_B,

source/source_base/module_container/base/macros/cuda.h

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,29 @@ static inline cusolverEigType_t cublas_eig_type(const int& itype)
121121
throw std::runtime_error("cublas_eig_mode: unknown diag");
122122
}
123123

124+
/**
125+
* @brief Converts a character specifying eigenvalue range to cuSOLVER enum.
126+
*
127+
* 'A' or 'a' -> CUSOLVER_EIG_RANGE_ALL: all eigenvalues
128+
* 'V' or 'v' -> CUSOLVER_EIG_RANGE_V: values in [vl, vu]
129+
* 'I' or 'i' -> CUSOLVER_EIG_RANGE_I: indices in [il, iu]
130+
*
131+
* @param range Character indicating selection mode ('A', 'V', 'I')
132+
* @return Corresponding cusolverEigRange_t enum value
133+
* @throws std::runtime_error if character is invalid
134+
*/
135+
static inline cusolverEigRange_t cublas_eig_range(const char& range)
136+
{
137+
if (range == 'A' || range == 'a')
138+
return CUSOLVER_EIG_RANGE_ALL;
139+
else if (range == 'V' || range == 'v')
140+
return CUSOLVER_EIG_RANGE_V;
141+
else if (range == 'I' || range == 'i')
142+
return CUSOLVER_EIG_RANGE_I;
143+
else
144+
throw std::runtime_error("cublas_eig_range: unknown range '" + std::string(1, range) + "'");
145+
}
146+
124147
// cuSOLVER API errors
125148
static const char* cusolverGetErrorEnum(cusolverStatus_t error)
126149
{
@@ -226,4 +249,4 @@ inline void cublasAssert(cublasStatus_t res, const char* file, int line)
226249
#define cudaCheckOnDebug()
227250
#endif
228251

229-
#endif // BASE_MACROS_CUDA_H_
252+
#endif // BASE_MACROS_CUDA_H_

0 commit comments

Comments
 (0)