Skip to content

Commit 93da750

Browse files
committed
Fix lapack_heevx and add template instantiation
1 parent 82b8acf commit 93da750

File tree

3 files changed

+22
-10
lines changed

3 files changed

+22
-10
lines changed

source/source_base/module_container/ATen/kernels/cuda/lapack.cu

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ struct lapack_heevx<T, DEVICE_GPU> {
107107
void operator()(
108108
const int n,
109109
const int lda,
110-
T *d_Mat,
110+
const T *d_Mat,
111111
const int neig,
112112
Real *d_eigen_val,
113113
T *d_eigen_vec)
@@ -232,6 +232,11 @@ template struct lapack_heevd<double, DEVICE_GPU>;
232232
template struct lapack_heevd<std::complex<float>, DEVICE_GPU>;
233233
template struct lapack_heevd<std::complex<double>, DEVICE_GPU>;
234234

235+
template struct lapack_heevx<float, DEVICE_GPU>;
236+
template struct lapack_heevx<double, DEVICE_GPU>;
237+
template struct lapack_heevx<std::complex<float>, DEVICE_GPU>;
238+
template struct lapack_heevx<std::complex<double>, DEVICE_GPU>;
239+
235240
template struct lapack_hegvd<float, DEVICE_GPU>;
236241
template struct lapack_hegvd<double, DEVICE_GPU>;
237242
template struct lapack_hegvd<std::complex<float>, DEVICE_GPU>;

source/source_base/module_container/ATen/kernels/lapack.cpp

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
#include "source_base/module_device/types.h"
12
#include <ATen/kernels/lapack.h>
23

34
#include <base/third_party/lapack.h>
45

56
// #include <cstring> // std::memcpy
67
#include <algorithm> // std::copy
8+
#include <complex>
79
#include <stdexcept>
810
#include <string>
911

@@ -108,15 +110,15 @@ struct lapack_heevx<T, DEVICE_CPU> {
108110
void operator()(
109111
const int n,
110112
const int lda,
111-
T *Mat,
113+
const T *Mat,
112114
const int neig,
113115
Real *eigen_val,
114116
T *eigen_vec)
115117
{
116118
Tensor aux(DataTypeToEnum<T>::value, DeviceType::CpuDevice, {n * lda});
117119
// Copy Mat to aux since heevx will destroy it
118120
// aux = Mat
119-
std::copy(Mat, Mat + n * lda, aux);
121+
std::copy(Mat, Mat + n * lda, aux.data<T>());
120122

121123
char jobz = 'V'; // Compute eigenvalues and eigenvectors
122124
char range = 'I'; // Find eigenvalues in index range [il, iu]
@@ -139,17 +141,17 @@ struct lapack_heevx<T, DEVICE_CPU> {
139141
// when lwork = -1
140142
lapackConnector::heevx(
141143
jobz, range, uplo, n,
142-
aux, lda,
144+
aux.data<T>(), lda,
143145
0.0, 0.0, il, iu, // vl, vu not used when range='I'
144146
abstol,
145-
&found,
147+
found,
146148
eigen_val,
147149
eigen_vec, lda,
148150
&work_query, lwork,
149151
&rwork_query,
150152
&iwork_query,
151153
&ifail_query,
152-
&info);
154+
info);
153155

154156
if (info != 0) {
155157
throw std::runtime_error("heevx workspace query failed with info = " + std::to_string(info));
@@ -173,17 +175,17 @@ struct lapack_heevx<T, DEVICE_CPU> {
173175
// Actual call to heevx
174176
lapackConnector::heevx(
175177
jobz, range, uplo, n,
176-
aux, lda,
178+
aux.data<T>(), lda,
177179
0.0, 0.0, il, iu,
178180
abstol,
179-
&found,
181+
found,
180182
eigen_val,
181183
eigen_vec, lda,
182184
work.data<T>(), lwork,
183185
rwork.data<Real>(),
184186
iwork.data<int>(),
185187
ifail.data<int>(),
186-
&info);
188+
info);
187189

188190
if (info != 0) {
189191
throw std::runtime_error("heevx failed with info = " + std::to_string(info));
@@ -388,6 +390,11 @@ template struct lapack_heevd<double, DEVICE_CPU>;
388390
template struct lapack_heevd<std::complex<float>, DEVICE_CPU>;
389391
template struct lapack_heevd<std::complex<double>, DEVICE_CPU>;
390392

393+
template struct lapack_heevx<float, DEVICE_CPU>;
394+
template struct lapack_heevx<double, DEVICE_CPU>;
395+
template struct lapack_heevx<std::complex<float>, DEVICE_CPU>;
396+
template struct lapack_heevx<std::complex<double>, DEVICE_CPU>;
397+
391398
template struct lapack_hegvd<float, DEVICE_CPU>;
392399
template struct lapack_hegvd<double, DEVICE_CPU>;
393400
template struct lapack_hegvd<std::complex<float>, DEVICE_CPU>;

source/source_base/module_container/ATen/kernels/lapack.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ struct lapack_heevx {
8181
void operator()(
8282
const int dim,
8383
const int lda,
84-
T *Mat,
84+
const T *Mat,
8585
const int neig,
8686
Real *eigen_val,
8787
T *eigen_vec);

0 commit comments

Comments
 (0)