Skip to content

Commit eee8b75

Browse files
committed
Phase 1 of RT-TDDFT GPU Acceleration: Rewriting existing code using Tensor
1 parent 28df43d commit eee8b75

File tree

15 files changed

+1193
-50
lines changed

15 files changed

+1193
-50
lines changed

source/module_esolver/esolver_ks_lcao_tddft.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,12 @@
2424
#include "module_io/print_info.h"
2525

2626
//-----HSolver ElecState Hamilt--------
27+
#include "module_elecstate/cal_ux.h"
2728
#include "module_elecstate/elecstate_lcao.h"
2829
#include "module_hamilt_lcao/hamilt_lcaodft/hamilt_lcao.h"
2930
#include "module_hsolver/hsolver_lcao.h"
3031
#include "module_parameter/parameter.h"
3132
#include "module_psi/psi.h"
32-
#include "module_elecstate/cal_ux.h"
3333

3434
//-----force& stress-------------------
3535
#include "module_hamilt_lcao/hamilt_lcaodft/FORCE_STRESS.h"
@@ -290,7 +290,12 @@ void ESolver_KS_LCAO_TDDFT::after_scf(UnitCell& ucell, const int istep)
290290
{
291291
std::stringstream ss_dipole;
292292
ss_dipole << PARAM.globalv.global_out_dir << "SPIN" << is + 1 << "_DIPOLE";
293-
ModuleIO::write_dipole(ucell,pelec->charge->rho_save[is], pelec->charge->rhopw, is, istep, ss_dipole.str());
293+
ModuleIO::write_dipole(ucell,
294+
pelec->charge->rho_save[is],
295+
pelec->charge->rhopw,
296+
is,
297+
istep,
298+
ss_dipole.str());
294299
}
295300
}
296301
if (TD_Velocity::out_current == true)

source/module_hamilt_lcao/module_tddft/bandenergy.cpp

Lines changed: 134 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
#include "bandenergy.h"
22

3-
#include <complex>
4-
#include <iostream>
5-
63
#include "evolve_elec.h"
74
#include "module_base/lapack_connector.h"
85
#include "module_base/scalapack_connector.h"
96

7+
#include <complex>
8+
#include <iostream>
9+
1010
namespace module_tddft
1111
{
1212
#ifdef __MPI
@@ -133,14 +133,144 @@ void compute_ekb(const Parallel_Orbitals* pv,
133133
}
134134
}
135135
} // loop ipcol
136-
} // loop iprow
136+
} // loop iprow
137137
info = MPI_Allreduce(Eii, ekb, nband, MPI_DOUBLE, MPI_SUM, pv->comm());
138138

139139
delete[] tmp1;
140140
delete[] Eij;
141141
delete[] Eii;
142142
}
143143

144+
void compute_ekb_tensor(const Parallel_Orbitals* pv,
145+
const int nband,
146+
const int nlocal,
147+
const container::Tensor& Htmp,
148+
const container::Tensor& psi_k,
149+
container::Tensor& ekb)
150+
{
151+
// Create Tensor objects for temporary data
152+
container::Tensor tmp1(container::DataType::DT_COMPLEX_DOUBLE,
153+
container::DeviceType::CpuDevice,
154+
container::TensorShape({pv->nloc_wfc}));
155+
tmp1.zero();
156+
157+
container::Tensor Eij(container::DataType::DT_COMPLEX_DOUBLE,
158+
container::DeviceType::CpuDevice,
159+
container::TensorShape({pv->nloc}));
160+
Eij.zero();
161+
162+
// Perform matrix multiplication: tmp1 = Htmp * psi_k
163+
ScalapackConnector::gemm('N',
164+
'N',
165+
nlocal,
166+
nband,
167+
nlocal,
168+
1.0,
169+
Htmp.data<std::complex<double>>(),
170+
1,
171+
1,
172+
pv->desc,
173+
psi_k.data<std::complex<double>>(),
174+
1,
175+
1,
176+
pv->desc_wfc,
177+
0.0,
178+
tmp1.data<std::complex<double>>(),
179+
1,
180+
1,
181+
pv->desc_wfc);
182+
183+
// Perform matrix multiplication: Eij = psi_k^dagger * tmp1
184+
ScalapackConnector::gemm('C',
185+
'N',
186+
nband,
187+
nband,
188+
nlocal,
189+
1.0,
190+
psi_k.data<std::complex<double>>(),
191+
1,
192+
1,
193+
pv->desc_wfc,
194+
tmp1.data<std::complex<double>>(),
195+
1,
196+
1,
197+
pv->desc_wfc,
198+
0.0,
199+
Eij.data<std::complex<double>>(),
200+
1,
201+
1,
202+
pv->desc_Eij);
203+
204+
if (Evolve_elec::td_print_eij >= 0.0)
205+
{
206+
GlobalV::ofs_running
207+
<< "------------------------------------------------------------------------------------------------"
208+
<< std::endl;
209+
GlobalV::ofs_running << " Eij:" << std::endl;
210+
for (int i = 0; i < pv->nrow_bands; i++)
211+
{
212+
for (int j = 0; j < pv->ncol_bands; j++)
213+
{
214+
double aa, bb;
215+
aa = Eij.data<std::complex<double>>()[i * pv->ncol + j].real();
216+
bb = Eij.data<std::complex<double>>()[i * pv->ncol + j].imag();
217+
if (std::abs(aa) < Evolve_elec::td_print_eij)
218+
aa = 0.0;
219+
if (std::abs(bb) < Evolve_elec::td_print_eij)
220+
bb = 0.0;
221+
if (aa > 0.0 || bb > 0.0)
222+
{
223+
GlobalV::ofs_running << i << " " << j << " " << aa << "+" << bb << "i " << std::endl;
224+
}
225+
}
226+
}
227+
GlobalV::ofs_running << std::endl;
228+
GlobalV::ofs_running
229+
<< "------------------------------------------------------------------------------------------------"
230+
<< std::endl;
231+
}
232+
233+
int info;
234+
int naroc[2];
235+
236+
// Create a Tensor for Eii
237+
container::Tensor Eii(container::DataType::DT_DOUBLE,
238+
container::DeviceType::CpuDevice,
239+
container::TensorShape({nband}));
240+
Eii.zero();
241+
242+
for (int iprow = 0; iprow < pv->dim0; ++iprow)
243+
{
244+
for (int ipcol = 0; ipcol < pv->dim1; ++ipcol)
245+
{
246+
if (iprow == pv->coord[0] && ipcol == pv->coord[1])
247+
{
248+
naroc[0] = pv->nrow;
249+
naroc[1] = pv->ncol;
250+
for (int j = 0; j < naroc[1]; ++j)
251+
{
252+
int igcol = globalIndex(j, pv->nb, pv->dim1, ipcol);
253+
if (igcol >= nband)
254+
continue;
255+
for (int i = 0; i < naroc[0]; ++i)
256+
{
257+
int igrow = globalIndex(i, pv->nb, pv->dim0, iprow);
258+
if (igrow >= nband)
259+
continue;
260+
if (igcol == igrow)
261+
{
262+
Eii.data<double>()[igcol] = Eij.data<std::complex<double>>()[j * naroc[0] + i].real();
263+
}
264+
}
265+
}
266+
}
267+
} // loop ipcol
268+
} // loop iprow
269+
270+
// Perform MPI reduction to compute ekb
271+
info = MPI_Allreduce(Eii.data<double>(), ekb.data<double>(), nband, MPI_DOUBLE, MPI_SUM, pv->comm());
272+
}
273+
144274
#endif
145275

146276
} // namespace module_tddft

source/module_hamilt_lcao/module_tddft/bandenergy.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#ifndef BANDENERGY_H
77
#define BANDENERGY_H
88

9+
#include "module_base/module_container/ATen/core/tensor.h" // container::Tensor
910
#include "module_basis/module_ao/parallel_orbitals.h"
1011

1112
#include <complex>
@@ -29,6 +30,13 @@ void compute_ekb(const Parallel_Orbitals* pv,
2930
const std::complex<double>* Htmp,
3031
const std::complex<double>* psi_k,
3132
double* ekb);
33+
34+
void compute_ekb_tensor(const Parallel_Orbitals* pv,
35+
const int nband,
36+
const int nlocal,
37+
const container::Tensor& Htmp,
38+
const container::Tensor& psi_k,
39+
container::Tensor& ekb);
3240
#endif
3341
} // namespace module_tddft
3442
#endif

source/module_hamilt_lcao/module_tddft/evolve_elec.cpp

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010

1111
namespace module_tddft
1212
{
13-
Evolve_elec::Evolve_elec(){};
14-
Evolve_elec::~Evolve_elec(){};
13+
Evolve_elec::Evolve_elec() {};
14+
Evolve_elec::~Evolve_elec() {};
1515

1616
double Evolve_elec::td_force_dt;
1717
bool Evolve_elec::td_vext;
@@ -73,6 +73,67 @@ void Evolve_elec::solve_psi(const int& istep,
7373
&(ekb(ik, 0)),
7474
htype,
7575
propagator);
76+
77+
const bool use_tensor = false;
78+
if (use_tensor)
79+
{
80+
std::cout << "Print ekb: " << std::endl;
81+
ekb.print(std::cout);
82+
std::cout << "nband = " << nband << std::endl;
83+
std::cout << "psi->get_nbands() = " << psi->get_nbands() << std::endl;
84+
std::cout << "nlocal = " << nlocal << std::endl;
85+
std::cout << "psi->get_nbasis() = " << psi->get_nbasis() << std::endl;
86+
std::cout << "ekb.nr = " << ekb.nr << std::endl;
87+
std::cout << "ekb.nc = " << ekb.nc << std::endl;
88+
89+
// Create TensorMap for psi_k, psi_k_laststep, H_laststep, S_laststep, ekb
90+
container::TensorMap psi_k_tensor(psi[0].get_pointer(),
91+
container::DataType::DT_COMPLEX_DOUBLE,
92+
container::DeviceType::CpuDevice,
93+
container::TensorShape({psi->get_nbands(), psi->get_nbasis()}));
94+
container::TensorMap psi_k_laststep_tensor(
95+
psi_laststep[0].get_pointer(),
96+
container::DataType::DT_COMPLEX_DOUBLE,
97+
container::DeviceType::CpuDevice,
98+
container::TensorShape({psi->get_nbands(), psi->get_nbasis()}));
99+
container::TensorMap H_laststep_tensor(Hk_laststep[ik],
100+
container::DataType::DT_COMPLEX_DOUBLE,
101+
container::DeviceType::CpuDevice,
102+
container::TensorShape({para_orb.nloc}));
103+
container::TensorMap S_laststep_tensor(Sk_laststep[ik],
104+
container::DataType::DT_COMPLEX_DOUBLE,
105+
container::DeviceType::CpuDevice,
106+
container::TensorShape({para_orb.nloc}));
107+
container::TensorMap ekb_tensor(&(ekb(ik, 0)),
108+
container::DataType::DT_DOUBLE,
109+
container::DeviceType::CpuDevice,
110+
container::TensorShape({nband}));
111+
112+
evolve_psi_tensor(nband,
113+
nlocal,
114+
&(para_orb),
115+
phm,
116+
psi_k_tensor,
117+
psi_k_laststep_tensor,
118+
H_laststep_tensor,
119+
S_laststep_tensor,
120+
ekb_tensor,
121+
htype,
122+
propagator);
123+
// evolve_psi_tensor(nband,
124+
// nlocal,
125+
// &(para_orb),
126+
// phm,
127+
// psi[0].get_pointer(),
128+
// psi_laststep[0].get_pointer(),
129+
// Hk_laststep[ik],
130+
// Sk_laststep[ik],
131+
// &(ekb(ik, 0)),
132+
// htype,
133+
// propagator);
134+
std::cout << "Print ekb tensor: " << std::endl;
135+
ekb.print(std::cout);
136+
}
76137
}
77138
else
78139
{

source/module_hamilt_lcao/module_tddft/evolve_elec.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
#include "module_base/global_function.h"
55
#include "module_base/global_variable.h"
6+
#include "module_base/module_container/ATen/core/tensor.h" // container::Tensor
7+
#include "module_base/module_container/ATen/core/tensor_map.h" // TensorMap
68
#include "module_esolver/esolver_ks_lcao.h"
79
#include "module_esolver/esolver_ks_lcao_tddft.h"
810
#include "module_hamilt_lcao/hamilt_lcaodft/hamilt_lcao.h"

0 commit comments

Comments
 (0)