Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@ if(BUILD_TESTING)
if(ENABLE_MPI)
add_subdirectory(test)
endif()
endif()
endif()
103 changes: 97 additions & 6 deletions source/source_base/module_container/ATen/kernels/cuda/lapack.cu
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
#include <cuda_runtime.h>
#include <thrust/complex.h>

#include <cassert>


namespace container {
namespace kernels {

Expand Down Expand Up @@ -101,22 +104,100 @@ struct lapack_heevd<T, DEVICE_GPU> {
}
};

/**
 * GPU specialization of lapack_heevx.
 *
 * Computes the eigenpairs with indices 1..neig of an n-by-n matrix by
 * delegating to cuSolverConnector::heevdx. The input matrix is not modified:
 * it is first copied device-to-device into d_eigen_vec, and the solver then
 * operates on that copy.
 *
 * Parameters of operator():
 *   n            order of the matrix
 *   lda          leading dimension of d_Mat / d_eigen_vec (n <= lda is asserted)
 *   d_Mat        input matrix on the device, n * lda elements are read
 *   neig         upper index of the requested range; the lower index is fixed to 1
 *   d_eigen_val  output buffer handed to the solver for the eigenvalues
 *                (presumably device memory, as the d_ prefix suggests -- confirm
 *                against cuSolverConnector::heevdx)
 *   d_eigen_vec  device buffer of at least n * lda elements; holds the copy of
 *                d_Mat on entry to the solver and the solver's result on exit
 */
template <typename T>
struct lapack_heevx<T, DEVICE_GPU> {
    using Real = typename GetTypeReal<T>::type;
    void operator()(
        const int n,
        const int lda,
        const T *d_Mat,
        const int neig,
        Real *d_eigen_val,
        T *d_eigen_vec)
    {
        assert(n <= lda);

        // Solver configuration, spelled out once so the call below reads cleanly.
        const char jobz = 'V';   // compute eigenvectors as well as eigenvalues
        const char uplo = 'L';   // use the lower triangle of the matrix
        const char range = 'I';  // select eigenpairs by index
        const int il = 1;        // first requested index
        const Real vl = Real(0); // value bounds are not used with range 'I'
        const Real vu = Real(0);

        // Receives the eigenvalue count reported by the solver; it is not
        // inspected afterwards.
        int n_found = 0;

        // Stage the input in d_eigen_vec: cuSolver overwrites the matrix it is
        // given with its results, so it must not be handed d_Mat directly.
        cudaErrcheck(cudaMemcpy(d_eigen_vec, d_Mat, sizeof(T) * n * lda, cudaMemcpyDeviceToDevice));

        cuSolverConnector::heevdx(cusolver_handle,
                                  n, lda, d_eigen_vec,
                                  jobz, uplo, range,
                                  il, neig,
                                  vl, vu,
                                  d_eigen_val,
                                  &n_found);
    }
};
/**
 * GPU specialization of lapack_hegvd.
 *
 * Solves the generalized eigenproblem for the pair (Mat_A, Mat_B) through
 * cuSolverConnector::hegvd with itype = 1, jobz = 'V' and uplo = 'L'
 * (in LAPACK/cuSOLVER convention: A*x = lambda*B*x, eigenvectors requested,
 * lower triangles referenced).
 *
 * Neither input matrix is modified: Mat_A is copied into eigen_vec, Mat_B is
 * copied into a temporary device buffer, and the solver works on the copies.
 *
 * Parameters of operator():
 *   dim        order of the matrices
 *   lda        leading dimension shared by Mat_A, Mat_B and eigen_vec
 *              (dim <= lda is asserted)
 *   Mat_A      device matrix A, dim * lda elements are read
 *   Mat_B      device matrix B, dim * lda elements are read
 *   eigen_val  output buffer handed to the solver for the eigenvalues
 *   eigen_vec  device buffer of at least dim * lda elements; holds the copy of
 *              Mat_A on entry to the solver and the solver's result on exit
 */
template <typename T>
struct lapack_hegvd<T, DEVICE_GPU> {
    using Real = typename GetTypeReal<T>::type;
    void operator()(
        const int dim,
        const int lda,
        T* Mat_A,
        T* Mat_B,
        Real* eigen_val,
        T *eigen_vec)
    {
        // Same precondition the sibling lapack_heevx checks: the copies below
        // move dim columns of stride lda.
        assert(dim <= lda);

        const int itype = 1;
        const char jobz = 'V';
        const char uplo = 'L';

        // The solver overwrites its A argument, so run it on a copy that lives
        // in the caller's eigenvector buffer.
        cudaErrcheck(cudaMemcpy(eigen_vec, Mat_A, sizeof(T) * dim * lda, cudaMemcpyDeviceToDevice));

        // prevent B from being overwritten by Cholesky
        T *d_B_backup = nullptr;
        cudaErrcheck(cudaMalloc(&d_B_backup, sizeof(T) * dim * lda));
        cudaErrcheck(cudaMemcpy(d_B_backup, Mat_B, sizeof(T) * dim * lda, cudaMemcpyDeviceToDevice));

        cuSolverConnector::hegvd(cusolver_handle, itype, jobz, uplo, dim,
                                 eigen_vec, lda,
                                 d_B_backup, lda,
                                 eigen_val);

        // NOTE(review): if cuSolverConnector::hegvd can throw, this buffer is
        // leaked on that path -- confirm its error-handling behaviour.
        cudaErrcheck(cudaFree(d_B_backup));
    }
};

/**
 * GPU specialization of lapack_hegvx.
 *
 * Requests the eigenpairs with indices 1..m of the generalized eigenproblem
 * for the pair (A, B) by forwarding to cuSolverConnector::hegvdx with
 * itype = 1, jobz = 'V', range = 'I' and uplo = 'U'.
 *
 * Parameters of operator():
 *   n          order of the matrices
 *   lda        leading dimension passed to the connector
 *   A, B       input matrices, forwarded unchanged to the connector
 *   m          upper index of the requested range; the lower index is fixed to 1
 *   eigen_val  output buffer for the eigenvalues
 *   eigen_vec  output buffer for the eigenvectors
 */
template <typename T>
struct lapack_hegvx<T, DEVICE_GPU> {
    using Real = typename GetTypeReal<T>::type;
    void operator()(
        const int n,
        const int lda,
        T *A,
        T *B,
        const int m,
        Real *eigen_val,
        T *eigen_vec)
    {
        // Receives the eigenvalue count reported by the solver; it is not
        // inspected afterwards.
        int n_found = 0;

        // Per the original author's note, cuSolverConnector::hegvdx itself
        // protects A and B from being overwritten and writes the eigenvectors
        // into eigen_vec, so no staging copies are made here.
        cuSolverConnector::hegvdx(cusolver_handle,
                                  1,    // itype
                                  'V',  // jobz: compute vectors
                                  'I',  // range: by index
                                  'U',  // uplo: upper triangle
                                  n, lda, A, B,
                                  Real(0), Real(0),  // vl, vu (unused)
                                  1, m,              // il, iu
                                  &n_found,
                                  eigen_val, eigen_vec);
    }
};



template <typename T>
struct lapack_getrf<T, DEVICE_GPU> {
void operator()(
Expand Down Expand Up @@ -180,11 +261,21 @@ template struct lapack_heevd<double, DEVICE_GPU>;
template struct lapack_heevd<std::complex<float>, DEVICE_GPU>;
template struct lapack_heevd<std::complex<double>, DEVICE_GPU>;

// Explicit instantiations of the GPU lapack_heevx specialization for real and
// complex scalars in single and double precision.
template struct lapack_heevx<float, DEVICE_GPU>;
template struct lapack_heevx<double, DEVICE_GPU>;
template struct lapack_heevx<std::complex<float>, DEVICE_GPU>;
template struct lapack_heevx<std::complex<double>, DEVICE_GPU>;

// Explicit instantiations of the GPU lapack_hegvd specialization for the same
// four scalar types.
template struct lapack_hegvd<float, DEVICE_GPU>;
template struct lapack_hegvd<double, DEVICE_GPU>;
template struct lapack_hegvd<std::complex<float>, DEVICE_GPU>;
template struct lapack_hegvd<std::complex<double>, DEVICE_GPU>;

// Explicit instantiations of the GPU lapack_hegvx specialization for the same
// four scalar types.
template struct lapack_hegvx<float, DEVICE_GPU>;
template struct lapack_hegvx<double, DEVICE_GPU>;
template struct lapack_hegvx<std::complex<float>, DEVICE_GPU>;
template struct lapack_hegvx<std::complex<double>, DEVICE_GPU>;

template struct lapack_getrf<float, DEVICE_GPU>;
template struct lapack_getrf<double, DEVICE_GPU>;
template struct lapack_getrf<std::complex<float>, DEVICE_GPU>;
Expand Down
Loading
Loading