1+ #include " source_base/module_device/types.h"
12#include < ATen/kernels/lapack.h>
23
34#include < base/third_party/lapack.h>
45
56// #include <cstring> // std::memcpy
67#include < algorithm> // std::copy
8+ #include < complex>
79#include < stdexcept>
810#include < string>
911
@@ -108,15 +110,15 @@ struct lapack_heevx<T, DEVICE_CPU> {
108110 void operator ()(
109111 const int n,
110112 const int lda,
111- T *Mat,
113+ const T *Mat,
112114 const int neig,
113115 Real *eigen_val,
114116 T *eigen_vec)
115117 {
116118 Tensor aux (DataTypeToEnum<T>::value, DeviceType::CpuDevice, {n * lda});
117119 // Copy Mat to aux since heevx will destroy it
118120 // aux = Mat
119- std::copy (Mat, Mat + n * lda, aux);
121+ std::copy (Mat, Mat + n * lda, aux. data <T>() );
120122
121123 char jobz = ' V' ; // Compute eigenvalues and eigenvectors
122124 char range = ' I' ; // Find eigenvalues in index range [il, iu]
@@ -139,17 +141,17 @@ struct lapack_heevx<T, DEVICE_CPU> {
139141 // when lwork = -1
140142 lapackConnector::heevx (
141143 jobz, range, uplo, n,
142- aux, lda,
144+ aux. data <T>() , lda,
143145 0.0 , 0.0 , il, iu, // vl, vu not used when range='I'
144146 abstol,
145- & found,
147+ found,
146148 eigen_val,
147149 eigen_vec, lda,
148150 &work_query, lwork,
149151 &rwork_query,
150152 &iwork_query,
151153 &ifail_query,
152- & info);
154+ info);
153155
154156 if (info != 0 ) {
155157 throw std::runtime_error (" heevx workspace query failed with info = " + std::to_string (info));
@@ -173,17 +175,17 @@ struct lapack_heevx<T, DEVICE_CPU> {
173175 // Actual call to heevx
174176 lapackConnector::heevx (
175177 jobz, range, uplo, n,
176- aux, lda,
178+ aux. data <T>() , lda,
177179 0.0 , 0.0 , il, iu,
178180 abstol,
179- & found,
181+ found,
180182 eigen_val,
181183 eigen_vec, lda,
182184 work.data <T>(), lwork,
183185 rwork.data <Real>(),
184186 iwork.data <int >(),
185187 ifail.data <int >(),
186- & info);
188+ info);
187189
188190 if (info != 0 ) {
189191 throw std::runtime_error (" heevx failed with info = " + std::to_string (info));
@@ -388,6 +390,11 @@ template struct lapack_heevd<double, DEVICE_CPU>;
388390template struct lapack_heevd <std::complex <float >, DEVICE_CPU>;
389391template struct lapack_heevd <std::complex <double >, DEVICE_CPU>;
390392
393+ template struct lapack_heevx <float , DEVICE_CPU>;
394+ template struct lapack_heevx <double , DEVICE_CPU>;
395+ template struct lapack_heevx <std::complex <float >, DEVICE_CPU>;
396+ template struct lapack_heevx <std::complex <double >, DEVICE_CPU>;
397+
391398template struct lapack_hegvd <float , DEVICE_CPU>;
392399template struct lapack_hegvd <double , DEVICE_CPU>;
393400template struct lapack_hegvd <std::complex <float >, DEVICE_CPU>;
0 commit comments