Skip to content

Commit 4870f18

Browse files
authored
Merge branch 'develop' into develop
2 parents 73965ea + f689729 commit 4870f18

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

54 files changed

+1587
-310
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,5 @@ __pycache__
2323
abacus.json
2424
*.npy
2525
toolchain/install/
26-
toolchain/abacus_env.sh
26+
toolchain/abacus_env.sh
27+
.trae

source/source_base/element_basis_index.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,24 +4,26 @@
44
//==========================================================
55

66
#include "element_basis_index.h"
7+
78
namespace ModuleBase
89
{
910

10-
Element_Basis_Index::IndexLNM Element_Basis_Index::construct_index( const Range &range )
11+
Element_Basis_Index::IndexLNM
12+
Element_Basis_Index::construct_index( const Range &range )
1113
{
1214
IndexLNM index;
1315
index.resize( range.size() );
14-
for( size_t T=0; T!=range.size(); ++T )
16+
for( std::size_t T=0; T!=range.size(); ++T )
1517
{
16-
size_t count=0;
18+
std::size_t count=0;
1719
index[T].resize( range[T].size() );
18-
for( size_t L=0; L!=range[T].size(); ++L )
20+
for( std::size_t L=0; L!=range[T].size(); ++L )
1921
{
2022
index[T][L].resize( range[T][L].N );
21-
for( size_t N=0; N!=range[T][L].N; ++N )
23+
for( std::size_t N=0; N!=range[T][L].N; ++N )
2224
{
2325
index[T][L][N].resize( range[T][L].M );
24-
for( size_t M=0; M!=range[T][L].M; ++M )
26+
for( std::size_t M=0; M!=range[T][L].M; ++M )
2527
{
2628
index[T][L][N][M] = count;
2729
++count;

source/source_base/element_basis_index.h

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,40 +8,41 @@
88

99
#include <cstddef>
1010
#include <vector>
11+
1112
namespace ModuleBase
1213
{
1314

14-
class Element_Basis_Index
15+
namespace Element_Basis_Index
1516
{
16-
private:
17-
17+
//private:
18+
1819
struct NM
1920
{
2021
public:
21-
size_t N;
22-
size_t M;
22+
std::size_t N;
23+
std::size_t M;
2324
};
24-
25-
class Index_TL: public std::vector<std::vector<size_t>>
25+
26+
class Index_TL: public std::vector<std::vector<std::size_t>>
2627
{
2728
public:
28-
size_t N;
29-
size_t M;
29+
std::size_t N;
30+
std::size_t M;
3031
};
31-
32+
3233
class Index_T: public std::vector<Index_TL>
3334
{
3435
public:
35-
size_t count_size;
36-
};
37-
38-
public:
39-
40-
typedef std::vector<std::vector<NM>> Range; // range[T][L]
36+
std::size_t count_size;
37+
};
38+
39+
//public:
40+
41+
typedef std::vector<std::vector<NM>> Range; // range[T][L]
4142
typedef std::vector<Index_T> IndexLNM; // index[T][L][N][M]
42-
43-
static IndexLNM construct_index( const Range &range );
44-
};
43+
44+
extern IndexLNM construct_index( const Range &range );
45+
}
4546

4647
}
4748

source/source_base/module_container/ATen/kernels/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,4 @@ if(BUILD_TESTING)
1616
if(ENABLE_MPI)
1717
add_subdirectory(test)
1818
endif()
19-
endif()
19+
endif()

source/source_base/module_container/ATen/kernels/cuda/lapack.cu

Lines changed: 97 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
#include <cuda_runtime.h>
77
#include <thrust/complex.h>
88

9+
#include <cassert>
10+
11+
912
namespace container {
1013
namespace kernels {
1114

@@ -101,22 +104,100 @@ struct lapack_heevd<T, DEVICE_GPU> {
101104
}
102105
};
103106

107+
template <typename T>
108+
struct lapack_heevx<T, DEVICE_GPU> {
109+
using Real = typename GetTypeReal<T>::type;
110+
void operator()(
111+
const int n,
112+
const int lda,
113+
const T *d_Mat,
114+
const int neig,
115+
Real *d_eigen_val,
116+
T *d_eigen_vec)
117+
{
118+
assert(n <= lda);
119+
// copy d_Mat to d_eigen_vec, and results will be overwritten into d_eigen_vec
120+
// by cuSolver
121+
cudaErrcheck(cudaMemcpy(d_eigen_vec, d_Mat, sizeof(T) * n * lda, cudaMemcpyDeviceToDevice));
122+
123+
int meig = 0;
124+
125+
cuSolverConnector::heevdx(
126+
cusolver_handle,
127+
n,
128+
lda,
129+
d_eigen_vec,
130+
'V', // jobz: compute vectors
131+
'L', // uplo: lower triangle
132+
'I', // range: by index
133+
1, neig, // il, iu
134+
Real(0), Real(0), // vl, vu (unused)
135+
d_eigen_val,
136+
&meig
137+
);
138+
139+
}
140+
};
104141
template <typename T>
105142
struct lapack_hegvd<T, DEVICE_GPU> {
106143
using Real = typename GetTypeReal<T>::type;
107144
void operator()(
108-
const int& itype,
109-
const char& jobz,
110-
const char& uplo,
145+
const int dim,
146+
const int lda,
111147
T* Mat_A,
112148
T* Mat_B,
113-
const int& dim,
114-
Real* eigen_val)
149+
Real* eigen_val,
150+
T *eigen_vec)
115151
{
116-
cuSolverConnector::hegvd(cusolver_handle, itype, jobz, uplo, dim, Mat_A, dim, Mat_B, dim, eigen_val);
152+
const int itype = 1;
153+
const char jobz = 'V';
154+
const char uplo = 'L';
155+
cudaErrcheck(cudaMemcpy(eigen_vec, Mat_A, sizeof(T) * dim * lda, cudaMemcpyDeviceToDevice));
156+
157+
// prevent B from being overwritten by Cholesky
158+
T *d_B_backup = nullptr;
159+
cudaErrcheck(cudaMalloc(&d_B_backup, sizeof(T) * dim * lda));
160+
cudaErrcheck(cudaMemcpy(d_B_backup, Mat_B, sizeof(T) * dim * lda, cudaMemcpyDeviceToDevice));
161+
162+
cuSolverConnector::hegvd(cusolver_handle, itype, jobz, uplo, dim,
163+
eigen_vec, lda,
164+
d_B_backup, lda,
165+
eigen_val);
166+
cudaErrcheck(cudaFree(d_B_backup));
167+
}
168+
};
169+
170+
template <typename T>
171+
struct lapack_hegvx<T, DEVICE_GPU> {
172+
using Real = typename GetTypeReal<T>::type;
173+
void operator()(
174+
const int n,
175+
const int lda,
176+
T *A,
177+
T *B,
178+
const int m,
179+
Real *eigen_val,
180+
T *eigen_vec)
181+
{
182+
const int itype = 1;
183+
const char jobz = 'V';
184+
const char range = 'I';
185+
const char uplo = 'U';
186+
int meig = 0;
187+
188+
// this hegvdx will protect the input A, B from being overwritten
189+
// and write the eigenvectors into eigen_vec.
190+
cuSolverConnector::hegvdx(cusolver_handle,
191+
itype, jobz, range, uplo,
192+
n, lda, A, B,
193+
Real(0), Real(0),
194+
1, m, &meig,
195+
eigen_val, eigen_vec);
117196
}
118197
};
119198

199+
200+
120201
template <typename T>
121202
struct lapack_getrf<T, DEVICE_GPU> {
122203
void operator()(
@@ -180,11 +261,21 @@ template struct lapack_heevd<double, DEVICE_GPU>;
180261
template struct lapack_heevd<std::complex<float>, DEVICE_GPU>;
181262
template struct lapack_heevd<std::complex<double>, DEVICE_GPU>;
182263

264+
template struct lapack_heevx<float, DEVICE_GPU>;
265+
template struct lapack_heevx<double, DEVICE_GPU>;
266+
template struct lapack_heevx<std::complex<float>, DEVICE_GPU>;
267+
template struct lapack_heevx<std::complex<double>, DEVICE_GPU>;
268+
183269
template struct lapack_hegvd<float, DEVICE_GPU>;
184270
template struct lapack_hegvd<double, DEVICE_GPU>;
185271
template struct lapack_hegvd<std::complex<float>, DEVICE_GPU>;
186272
template struct lapack_hegvd<std::complex<double>, DEVICE_GPU>;
187273

274+
template struct lapack_hegvx<float, DEVICE_GPU>;
275+
template struct lapack_hegvx<double, DEVICE_GPU>;
276+
template struct lapack_hegvx<std::complex<float>, DEVICE_GPU>;
277+
template struct lapack_hegvx<std::complex<double>, DEVICE_GPU>;
278+
188279
template struct lapack_getrf<float, DEVICE_GPU>;
189280
template struct lapack_getrf<double, DEVICE_GPU>;
190281
template struct lapack_getrf<std::complex<float>, DEVICE_GPU>;

0 commit comments

Comments
 (0)