Skip to content

Commit eda3add

Browse files
committed
RT-TDDFT GPU Acceleration (Phase 2): Adding needed BLAS and LAPACK support for Tensor on CPU and refactoring linear algebra operations in TDDFT
1 parent 3110720 commit eda3add

File tree

17 files changed

+924 additions
-163 deletions

source/module_base/lapack_connector.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -133,8 +133,8 @@ extern "C"
133133

134134
// zgetrf computes the LU factorization of a general matrix
135135
// while zgetri takes its output to perform matrix inversion
136-
void zgetrf_(const int* m, const int *n, const std::complex<double> *A, const int *lda, int *ipiv, const int* info);
137-
void zgetri_(const int* n, std::complex<double> *A, const int *lda, int *ipiv, std::complex<double> *work, int *lwork, const int *info);
136+
void zgetrf_(const int* m, const int *n, std::complex<double> *A, const int *lda, int *ipiv, int* info);
137+
void zgetri_(const int* n, std::complex<double>* A, const int* lda, const int* ipiv, std::complex<double>* work, const int* lwork, int* info);
138138

139139
// if trans=='N': C = alpha * A * A.H + beta * C
140140
// if trans=='C': C = alpha * A.H * A + beta * C

source/module_base/module_container/ATen/kernels/lapack.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,41 @@ struct lapack_dngvd<T, DEVICE_CPU> {
124124
}
125125
};
126126

127+
template <typename T>
128+
struct lapack_getrf<T, DEVICE_CPU> {
129+
void operator()(
130+
const int& m,
131+
const int& n,
132+
T* Mat,
133+
const int& lda,
134+
int* ipiv,
135+
int& info)
136+
{
137+
lapackConnector::getrf(m, n, Mat, lda, ipiv, info);
138+
if (info != 0) {
139+
throw std::runtime_error("getrf failed with info = " + std::to_string(info));
140+
}
141+
}
142+
};
143+
144+
template <typename T>
145+
struct lapack_getri<T, DEVICE_CPU> {
146+
void operator()(
147+
const int& n,
148+
T* Mat,
149+
const int& lda,
150+
const int* ipiv,
151+
T* work,
152+
const int& lwork,
153+
int& info)
154+
{
155+
lapackConnector::getri(n, Mat, lda, ipiv, work, lwork, info);
156+
if (info != 0) {
157+
throw std::runtime_error("getri failed with info = " + std::to_string(info));
158+
}
159+
}
160+
};
161+
127162
template struct set_matrix<float, DEVICE_CPU>;
128163
template struct set_matrix<double, DEVICE_CPU>;
129164
template struct set_matrix<std::complex<float>, DEVICE_CPU>;
@@ -149,5 +184,15 @@ template struct lapack_dngvd<double, DEVICE_CPU>;
149184
template struct lapack_dngvd<std::complex<float>, DEVICE_CPU>;
150185
template struct lapack_dngvd<std::complex<double>, DEVICE_CPU>;
151186

187+
template struct lapack_getrf<float, DEVICE_CPU>;
188+
template struct lapack_getrf<double, DEVICE_CPU>;
189+
template struct lapack_getrf<std::complex<float>, DEVICE_CPU>;
190+
template struct lapack_getrf<std::complex<double>, DEVICE_CPU>;
191+
192+
template struct lapack_getri<float, DEVICE_CPU>;
193+
template struct lapack_getri<double, DEVICE_CPU>;
194+
template struct lapack_getri<std::complex<float>, DEVICE_CPU>;
195+
template struct lapack_getri<std::complex<double>, DEVICE_CPU>;
196+
152197
} // namespace kernels
153198
} // namespace container

source/module_base/module_container/ATen/kernels/lapack.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,32 @@ struct lapack_dngvd {
6565
Real* eigen_val);
6666
};
6767

68+
69+
template <typename T, typename Device>
70+
struct lapack_getrf {
71+
void operator()(
72+
const int& m,
73+
const int& n,
74+
T* Mat,
75+
const int& lda,
76+
int* ipiv,
77+
int& info);
78+
};
79+
80+
81+
template <typename T, typename Device>
82+
struct lapack_getri {
83+
void operator()(
84+
const int& n,
85+
T* Mat,
86+
const int& lda,
87+
const int* ipiv,
88+
T* work,
89+
const int& lwork,
90+
int& info);
91+
};
92+
93+
6894
#if defined(__CUDA) || defined(__ROCM)
6995
// TODO: Use C++ singleton to manage the GPU handles
7096
void createGpuSolverHandle(); // create cusolver handle

source/module_base/module_container/base/third_party/lapack.h

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,15 @@ void dtrtri_(const char* uplo, const char* diag, const int* n, double* a, const
105105
void ctrtri_(const char* uplo, const char* diag, const int* n, std::complex<float>* a, const int* lda, int* info);
106106
void ztrtri_(const char* uplo, const char* diag, const int* n, std::complex<double>* a, const int* lda, int* info);
107107

108+
void sgetrf_(const int* m, const int* n, float* a, const int* lda, int* ipiv, int* info);
109+
void dgetrf_(const int* m, const int* n, double* a, const int* lda, int* ipiv, int* info);
110+
void cgetrf_(const int* m, const int* n, std::complex<float>* a, const int* lda, int* ipiv, int* info);
111+
void zgetrf_(const int* m, const int* n, std::complex<double>* a, const int* lda, int* ipiv, int* info);
112+
113+
void sgetri_(const int* n, float* A, const int* lda, const int* ipiv, float* work, const int* lwork, int* info);
114+
void dgetri_(const int* n, double* A, const int* lda, const int* ipiv, double* work, const int* lwork, int* info);
115+
void cgetri_(const int* n, std::complex<float>* A, const int* lda, const int* ipiv, std::complex<float>* work, const int* lwork, int* info);
116+
void zgetri_(const int* n, std::complex<double>* A, const int* lda, const int* ipiv, std::complex<double>* work, const int* lwork, int* info);
108117
}
109118

110119
// Class LapackConnector provides the connector to Fortran LAPACK routines.
@@ -321,6 +330,48 @@ void trtri( const char &uplo, const char &diag, const int &n, std::complex<doubl
321330
ztrtri_( &uplo, &diag, &n, A, &lda, &info);
322331
}
323332

333+
static inline
334+
void getrf(const int &m, const int &n, float* A, const int &lda, int* ipiv, int &info)
335+
{
336+
sgetrf_(&m, &n, A, &lda, ipiv, &info);
337+
}
338+
static inline
339+
void getrf(const int &m, const int &n, double* A, const int &lda, int* ipiv, int &info)
340+
{
341+
dgetrf_(&m, &n, A, &lda, ipiv, &info);
342+
}
343+
static inline
344+
void getrf(const int &m, const int &n, std::complex<float>* A, const int &lda, int* ipiv, int &info)
345+
{
346+
cgetrf_(&m, &n, A, &lda, ipiv, &info);
347+
}
348+
static inline
349+
void getrf(const int &m, const int &n, std::complex<double>* A, const int &lda, int* ipiv, int &info)
350+
{
351+
zgetrf_(&m, &n, A, &lda, ipiv, &info);
352+
}
353+
354+
static inline
355+
void getri(const int& n, float* A, const int& lda, const int* ipiv, float* work, const int& lwork, int& info)
356+
{
357+
sgetri_(&n, A, &lda, ipiv, work, &lwork, &info);
358+
}
359+
static inline
360+
void getri(const int& n, double* A, const int& lda, const int* ipiv, double* work, const int& lwork, int& info)
361+
{
362+
dgetri_(&n, A, &lda, ipiv, work, &lwork, &info);
363+
}
364+
static inline
365+
void getri(const int& n, std::complex<float>* A, const int& lda, const int* ipiv, std::complex<float>* work, const int& lwork, int& info)
366+
{
367+
cgetri_(&n, A, &lda, ipiv, work, &lwork, &info);
368+
}
369+
static inline
370+
void getri(const int& n, std::complex<double>* A, const int& lda, const int* ipiv, std::complex<double>* work, const int& lwork, int& info)
371+
{
372+
zgetri_(&n, A, &lda, ipiv, work, &lwork, &info);
373+
}
374+
324375
} // namespace lapackConnector
325376
} // namespace container
326377

source/module_hamilt_lcao/module_tddft/bandenergy.cpp

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include "evolve_elec.h"
44
#include "module_base/lapack_connector.h"
5+
#include "module_base/module_container/ATen/kernels/blas.h"
56
#include "module_base/scalapack_connector.h"
67

78
#include <complex>
@@ -271,6 +272,94 @@ void compute_ekb_tensor(const Parallel_Orbitals* pv,
271272
info = MPI_Allreduce(Eii.data<double>(), ekb.data<double>(), nband, MPI_DOUBLE, MPI_SUM, pv->comm());
272273
}
273274

275+
void compute_ekb_tensor_lapack(const Parallel_Orbitals* pv,
276+
const int nband,
277+
const int nlocal,
278+
const container::Tensor& Htmp,
279+
const container::Tensor& psi_k,
280+
container::Tensor& ekb)
281+
{
282+
// Create Tensor objects for temporary data
283+
container::Tensor tmp1(container::DataType::DT_COMPLEX_DOUBLE,
284+
container::DeviceType::CpuDevice,
285+
container::TensorShape({pv->nloc_wfc})); // tmp1 shape: nlocal * nband
286+
tmp1.zero();
287+
288+
container::Tensor Eij(container::DataType::DT_COMPLEX_DOUBLE,
289+
container::DeviceType::CpuDevice,
290+
container::TensorShape({pv->nloc})); // Eij shape: nlocal * nlocal
291+
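// NOTE(review): only an nband x nband block of Eij is written below (leading dimension nlocal); could it be allocated as nband * nband instead?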
292+
Eij.zero();
293+
294+
std::complex<double> alpha = {1.0, 0.0};
295+
std::complex<double> beta = {0.0, 0.0};
296+
297+
// Perform matrix multiplication: tmp1 = Htmp * psi_k
298+
container::kernels::blas_gemm<std::complex<double>, container::DEVICE_CPU>()('N',
299+
'N',
300+
nlocal,
301+
nband,
302+
nlocal,
303+
&alpha,
304+
Htmp.data<std::complex<double>>(),
305+
nlocal, // Leading dimension of Htmp
306+
psi_k.data<std::complex<double>>(),
307+
nlocal, // Leading dimension of psi_k
308+
&beta,
309+
tmp1.data<std::complex<double>>(),
310+
nlocal); // Leading dimension of tmp1
311+
312+
// Perform matrix multiplication: Eij = psi_k^dagger * tmp1
313+
container::kernels::blas_gemm<std::complex<double>, container::DEVICE_CPU>()('C',
314+
'N',
315+
nband,
316+
nband,
317+
nlocal,
318+
&alpha,
319+
psi_k.data<std::complex<double>>(),
320+
nlocal, // Leading dimension of psi_k
321+
tmp1.data<std::complex<double>>(),
322+
nlocal, // Leading dimension of tmp1
323+
&beta,
324+
Eij.data<std::complex<double>>(),
325+
nlocal); // Leading dimension of Eij
326+
327+
if (Evolve_elec::td_print_eij >= 0.0)
328+
{
329+
GlobalV::ofs_running
330+
<< "------------------------------------------------------------------------------------------------"
331+
<< std::endl;
332+
GlobalV::ofs_running << " Eij:" << std::endl;
333+
for (int i = 0; i < pv->nrow_bands; i++)
334+
{
335+
for (int j = 0; j < pv->ncol_bands; j++)
336+
{
337+
double aa = 0.0, bb = 0.0;
338+
aa = Eij.data<std::complex<double>>()[i * pv->ncol + j].real();
339+
bb = Eij.data<std::complex<double>>()[i * pv->ncol + j].imag();
340+
if (std::abs(aa) < Evolve_elec::td_print_eij)
341+
aa = 0.0;
342+
if (std::abs(bb) < Evolve_elec::td_print_eij)
343+
bb = 0.0;
344+
if (aa > 0.0 || bb > 0.0)
345+
{
346+
GlobalV::ofs_running << i << " " << j << " " << aa << "+" << bb << "i " << std::endl;
347+
}
348+
}
349+
}
350+
GlobalV::ofs_running << std::endl;
351+
GlobalV::ofs_running
352+
<< "------------------------------------------------------------------------------------------------"
353+
<< std::endl;
354+
}
355+
356+
// Extract diagonal elements of Eij into ekb
357+
for (int i = 0; i < nband; ++i)
358+
{
359+
ekb.data<double>()[i] = Eij.data<std::complex<double>>()[i * nlocal + i].real();
360+
}
361+
}
362+
274363
#endif
275364

276365
} // namespace module_tddft

source/module_hamilt_lcao/module_tddft/bandenergy.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,13 @@ void compute_ekb_tensor(const Parallel_Orbitals* pv,
3737
const container::Tensor& Htmp,
3838
const container::Tensor& psi_k,
3939
container::Tensor& ekb);
40+
41+
void compute_ekb_tensor_lapack(const Parallel_Orbitals* pv,
42+
const int nband,
43+
const int nlocal,
44+
const container::Tensor& Htmp,
45+
const container::Tensor& psi_k,
46+
container::Tensor& ekb);
4047
#endif
4148
} // namespace module_tddft
4249
#endif

source/module_hamilt_lcao/module_tddft/evolve_elec.cpp

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@ void Evolve_elec::solve_psi(const int& istep,
3939
ModuleBase::TITLE("Evolve_elec", "solve_psi");
4040
ModuleBase::timer::tick("Evolve_elec", "solve_psi");
4141

42+
const int print_matrix = 0;
43+
// const bool use_tensor = true;
44+
const bool use_tensor = false;
45+
const bool use_lapack = true;
46+
4247
for (int ik = 0; ik < nks; ik++)
4348
{
4449
phm->updateHk(ik);
@@ -58,12 +63,11 @@ void Evolve_elec::solve_psi(const int& istep,
5863
nullptr,
5964
&(ekb(ik, 0)),
6065
htype,
61-
propagator);
66+
propagator,
67+
print_matrix);
6268
}
6369
else if (htype == 1)
6470
{
65-
// const bool use_tensor = true;
66-
const bool use_tensor = false;
6771
if (!use_tensor)
6872
{
6973
evolve_psi(nband,
@@ -76,7 +80,8 @@ void Evolve_elec::solve_psi(const int& istep,
7680
Sk_laststep[ik],
7781
&(ekb(ik, 0)),
7882
htype,
79-
propagator);
83+
propagator,
84+
print_matrix);
8085
// std::cout << "Print ekb: " << std::endl;
8186
// ekb.print(std::cout);
8287
}
@@ -122,18 +127,10 @@ void Evolve_elec::solve_psi(const int& istep,
122127
S_laststep_tensor,
123128
ekb_tensor,
124129
htype,
125-
propagator);
126-
// evolve_psi_tensor(nband,
127-
// nlocal,
128-
// &(para_orb),
129-
// phm,
130-
// psi[0].get_pointer(),
131-
// psi_laststep[0].get_pointer(),
132-
// Hk_laststep[ik],
133-
// Sk_laststep[ik],
134-
// &(ekb(ik, 0)),
135-
// htype,
136-
// propagator);
130+
propagator,
131+
print_matrix,
132+
use_lapack);
133+
137134
// std::cout << "Print ekb tensor: " << std::endl;
138135
// ekb.print(std::cout);
139136

0 commit comments

Comments
 (0)