Skip to content

Commit 7194eb7

Browse files
Refactor: replace sto_hchi by HamiltSdftPW::hPsi (#5298)
* change sto_hchi to hamilt_sdft_pw
* add is_first_node parameter for act function
* optimize hPsi
* [pre-commit.ci lite] apply automatic fixes
* fix compile error
* fix compile error and add UTs for hamilt_sdft
* fix CUDA compile
* fix wrong setmem
* [pre-commit.ci lite] apply automatic fixes
* fix undefined hpsi
* fix compile in sdft
* optimize for

---------

Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
1 parent 0bac03f commit 7194eb7

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+591
-430
lines changed

source/Makefile.Objects

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,7 @@ OBJS_GINT=gint.o\
291291
init_orb.o\
292292

293293
OBJS_HAMILT=hamilt_pw.o\
294+
hamilt_sdft_pw.o\
294295
operator.o\
295296
operator_pw.o\
296297
ekinetic_pw.o\
@@ -648,7 +649,6 @@ OBJS_SRCPW=H_Ewald_pw.o\
648649
structure_factor_k.o\
649650
soc.o\
650651
sto_iter.o\
651-
sto_hchi.o\
652652
sto_che.o\
653653
sto_wf.o\
654654
sto_func.o\

source/module_base/math_chebyshev.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,7 @@ bool Chebyshev<REAL>::checkconverge(
575575
std::function<void(std::complex<REAL>* in, std::complex<REAL>* out, const int)> funA,
576576
std::complex<REAL>* wavein,
577577
const int N,
578+
const int LDA,
578579
REAL& tmax,
579580
REAL& tmin,
580581
REAL stept)
@@ -584,9 +585,9 @@ bool Chebyshev<REAL>::checkconverge(
584585
std::complex<REAL>* arrayn;
585586
std::complex<REAL>* arrayn_1;
586587

587-
arraynp1 = new std::complex<REAL>[N];
588-
arrayn = new std::complex<REAL>[N];
589-
arrayn_1 = new std::complex<REAL>[N];
588+
arraynp1 = new std::complex<REAL>[LDA];
589+
arrayn = new std::complex<REAL>[LDA];
590+
arrayn_1 = new std::complex<REAL>[LDA];
590591

591592
ModuleBase::GlobalFunc::DCOPY(wavein, arrayn_1, N);
592593

source/module_base/math_chebyshev.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ class Chebyshev
191191
bool checkconverge(std::function<void(std::complex<REAL>* in, std::complex<REAL>* out, const int)> funA,
192192
std::complex<REAL>* wavein,
193193
const int N,
194+
const int LDA,
194195
REAL& tmax, // trial number for upper bound
195196
REAL& tmin, // trial number for lower bound
196197
REAL stept); // tmax = max() + stept, tmin = min() - stept

source/module_base/test/math_chebyshev_test.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -346,15 +346,15 @@ TEST_F(MathChebyshevTest, checkconverge)
346346
double tmin = -1.1;
347347
double tmax = 1.1;
348348
bool converge;
349-
converge = p_chetest->checkconverge(fun_sigma_y, v, 2, tmax, tmin, 0.2);
349+
converge = p_chetest->checkconverge(fun_sigma_y, v, 2, 2, tmax, tmin, 0.2);
350350
EXPECT_TRUE(converge);
351-
converge = p_chetest->checkconverge(fun_sigma_y, v + 2, 2, tmax, tmin, 0.2);
351+
converge = p_chetest->checkconverge(fun_sigma_y, v + 2, 2, 2, tmax, tmin, 0.2);
352352
EXPECT_TRUE(converge);
353353
EXPECT_NEAR(tmin, -1.1, 1e-8);
354354
EXPECT_NEAR(tmax, 1.1, 1e-8);
355355

356356
tmax = -1.1;
357-
converge = p_chetest->checkconverge(fun_sigma_y, v, 2, tmax, tmin, 2.2);
357+
converge = p_chetest->checkconverge(fun_sigma_y, v, 2, 2, tmax, tmin, 2.2);
358358
EXPECT_TRUE(converge);
359359
EXPECT_NEAR(tmin, -1.1, 1e-8);
360360
EXPECT_NEAR(tmax, 1.1, 1e-8);
@@ -363,12 +363,12 @@ TEST_F(MathChebyshevTest, checkconverge)
363363
v[0] = std::complex<double>(0, 1), v[1] = 1;
364364
fun.factor = 1.5;
365365
tmin = -1.1, tmax = 1.1;
366-
converge = p_chetest->checkconverge(fun_sigma_y, v, 2, tmax, tmin, 0.2);
366+
converge = p_chetest->checkconverge(fun_sigma_y, v, 2, 2, tmax, tmin, 0.2);
367367
EXPECT_FALSE(converge);
368368

369369
fun.factor = -1.5;
370370
tmin = -1.1, tmax = 1.1;
371-
converge = p_chetest->checkconverge(fun_sigma_y, v, 2, tmax, tmin, 0.2);
371+
converge = p_chetest->checkconverge(fun_sigma_y, v, 2, 2, tmax, tmin, 0.2);
372372
EXPECT_FALSE(converge);
373373
fun.factor = 1;
374374

@@ -632,9 +632,9 @@ TEST_F(MathChebyshevTest, checkconverge_float)
632632

633633
auto fun_sigma_yf
634634
= [&](std::complex<float>* in, std::complex<float>* out, const int m = 1) { fun.sigma_y(in, out, m); };
635-
converge = p_fchetest->checkconverge(fun_sigma_yf, v, 2, tmax, tmin, 0.2);
635+
converge = p_fchetest->checkconverge(fun_sigma_yf, v, 2, 2, tmax, tmin, 0.2);
636636
EXPECT_TRUE(converge);
637-
converge = p_fchetest->checkconverge(fun_sigma_yf, v + 2, 2, tmax, tmin, 0.2);
637+
converge = p_fchetest->checkconverge(fun_sigma_yf, v + 2, 2, 2, tmax, tmin, 0.2);
638638
EXPECT_TRUE(converge);
639639
EXPECT_NEAR(tmin, -1.1, 1e-6);
640640
EXPECT_NEAR(tmax, 1.1, 1e-6);

source/module_esolver/esolver_sdft_pw.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,15 @@ void ESolver_SDFT_PW::before_all_runners(const Input_para& inp, UnitCell& ucell)
133133
void ESolver_SDFT_PW::before_scf(const int istep)
134134
{
135135
ESolver_KS_PW::before_scf(istep);
136+
delete reinterpret_cast<hamilt::HamiltPW<double>*>(this->p_hamilt);
137+
this->p_hamilt = new hamilt::HamiltSdftPW<std::complex<double>>(this->pelec->pot,
138+
this->pw_wfc,
139+
&this->kv,
140+
PARAM.globalv.npol,
141+
&this->stoche.emin_sto,
142+
&this->stoche.emax_sto);
143+
this->p_hamilt_sto = static_cast<hamilt::HamiltSdftPW<std::complex<double>>*>(this->p_hamilt);
144+
136145
if (istep > 0 && PARAM.inp.nbands_sto != 0 && PARAM.inp.initsto_freq > 0 && istep % PARAM.inp.initsto_freq == 0)
137146
{
138147
Update_Sto_Orbitals(this->stowf, PARAM.inp.seed_sto);
@@ -177,7 +186,8 @@ void ESolver_SDFT_PW::hamilt2density(int istep, int iter, double ethr)
177186
this->pw_wfc,
178187
&this->wf,
179188
this->stowf,
180-
this->stoche,
189+
this->stoche,
190+
this->p_hamilt_sto,
181191
PARAM.inp.calculation,
182192
PARAM.inp.basis_type,
183193
PARAM.inp.ks_solver,

source/module_esolver/esolver_sdft_pw.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
#define ESOLVER_SDFT_PW_H
33

44
#include "esolver_ks_pw.h"
5-
#include "module_hamilt_pw/hamilt_stodft/sto_hchi.h"
65
#include "module_hamilt_pw/hamilt_stodft/sto_iter.h"
76
#include "module_hamilt_pw/hamilt_stodft/sto_wf.h"
87
#include "module_hamilt_pw/hamilt_stodft/sto_che.h"
8+
#include "module_hamilt_pw/hamilt_stodft/hamilt_sdft_pw.h"
99

1010
namespace ModuleESolver
1111
{
@@ -27,6 +27,7 @@ class ESolver_SDFT_PW : public ESolver_KS_PW<std::complex<double>>
2727
public:
2828
Stochastic_WF stowf;
2929
StoChe<double> stoche;
30+
hamilt::HamiltSdftPW<std::complex<double>>* p_hamilt_sto = nullptr;
3031

3132
protected:
3233
virtual void before_scf(const int istep) override;

source/module_hamilt_general/operator.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ typename Operator<T, Device>::hpsi_info Operator<T, Device>::hPsi(hpsi_info& inp
6060
this->hpsi = new psi::Psi<T, Device>(hpsi_pointer, *psi_input, 1, nbands / psi_input->npol);
6161
}
6262

63-
auto call_act = [&, this](const Operator* op) -> void {
63+
auto call_act = [&, this](const Operator* op, const bool& is_first_node) -> void {
6464
// a "psi" with the bands of needed range
6565
psi::Psi<T, Device> psi_wrapper(const_cast<T*>(tmpsi_in), 1, nbands, psi_input->get_nbasis());
6666
switch (op->get_act_type())
@@ -69,17 +69,17 @@ typename Operator<T, Device>::hpsi_info Operator<T, Device>::hPsi(hpsi_info& inp
6969
op->act(psi_wrapper, *this->hpsi, nbands);
7070
break;
7171
default:
72-
op->act(nbands, psi_input->get_nbasis(), psi_input->npol, tmpsi_in, this->hpsi->get_pointer(), psi_input->get_ngk(op->ik));
72+
op->act(nbands, psi_input->get_nbasis(), psi_input->npol, tmpsi_in, this->hpsi->get_pointer(), psi_input->get_ngk(op->ik), is_first_node);
7373
break;
7474
}
7575
};
7676

7777
ModuleBase::timer::tick("Operator", "hPsi");
78-
call_act(this);
78+
call_act(this, true); // first node
7979
Operator* node((Operator*)this->next_op);
8080
while (node != nullptr)
8181
{
82-
call_act(node);
82+
call_act(node, false); // other nodes
8383
node = (Operator*)(node->next_op);
8484
}
8585
ModuleBase::timer::tick("Operator", "hPsi");
@@ -162,7 +162,7 @@ T* Operator<T, Device>::get_hpsi(const hpsi_info& info) const
162162
size_t total_hpsi_size = nbands_range * this->hpsi->get_nbasis();
163163
// ModuleBase::GlobalFunc::ZEROS(hpsi_pointer, total_hpsi_size);
164164
// denghui replaced at 20221104
165-
set_memory_op()(this->ctx, hpsi_pointer, 0, total_hpsi_size);
165+
// set_memory_op()(this->ctx, hpsi_pointer, 0, total_hpsi_size);
166166
return hpsi_pointer;
167167
}
168168

source/module_hamilt_general/operator.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,14 @@ class Operator
5757
///do operation : |hpsi_choosed> = V|psi_choosed>
5858
///V is the target operator acting on the chosen psi; the result should be added to the chosen hpsi
5959
/// interface type 1: pointer-only (default)
60+
/// @note For PW: nbasis = max_npw * npol and nbands = nband * npol, while npol is passed unchanged. This is counter-intuitive, so pay attention!
6061
virtual void act(const int nbands,
6162
const int nbasis,
6263
const int npol,
6364
const T* tmpsi_in,
6465
T* tmhpsi,
65-
const int ngk_ik = 0)const {};
66+
const int ngk_ik = 0,
67+
const bool is_first_node = false)const {};
6668

6769
/// developer-friendly interfaces for act() function
6870
/// interface type 2: input and change the Psi-type HPsi

source/module_hamilt_pw/hamilt_pwdft/kernels/cuda/ekinetic_op.cu

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,32 @@ template <typename FPTYPE>
1313
__global__ void ekinetic_pw(
1414
const int npw,
1515
const int max_npw,
16+
const bool is_first_node,
1617
const FPTYPE tpiba2,
1718
const FPTYPE* gk2,
1819
thrust::complex<FPTYPE>* hpsi,
1920
const thrust::complex<FPTYPE>* psi)
2021
{
2122
const int block_idx = blockIdx.x;
2223
const int thread_idx = threadIdx.x;
23-
for (int ii = thread_idx; ii < npw; ii+= blockDim.x) {
24-
hpsi[block_idx * max_npw + ii]
25-
+= gk2[ii] * tpiba2 * psi[block_idx * max_npw + ii];
24+
const int start_idx = block_idx * max_npw;
25+
if(is_first_node)
26+
{
27+
for (int ii = thread_idx; ii < npw; ii += blockDim.x)
28+
{
29+
hpsi[start_idx + ii] = gk2[ii] * tpiba2 * psi[start_idx + ii];
30+
}
31+
for (int ii = npw + thread_idx; ii < max_npw; ii += blockDim.x)
32+
{
33+
hpsi[start_idx + ii] = 0.0;
34+
}
35+
}
36+
else
37+
{
38+
for (int ii = thread_idx; ii < npw; ii += blockDim.x)
39+
{
40+
hpsi[start_idx + ii] += gk2[ii] * tpiba2 * psi[start_idx + ii];
41+
}
2642
}
2743
}
2844

@@ -31,6 +47,7 @@ void hamilt::ekinetic_pw_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const b
3147
const int& nband,
3248
const int& npw,
3349
const int& max_npw,
50+
const bool& is_first_node,
3451
const FPTYPE& tpiba2,
3552
const FPTYPE* gk2_ik,
3653
std::complex<FPTYPE>* tmhpsi,
@@ -39,7 +56,7 @@ void hamilt::ekinetic_pw_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const b
3956
// denghui implement 20221019
4057
// <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
4158
ekinetic_pw<FPTYPE><<<nband, THREADS_PER_BLOCK>>>(
42-
npw, max_npw, tpiba2, // control params
59+
npw, max_npw, is_first_node, tpiba2, // control params
4360
gk2_ik, // array of data
4461
reinterpret_cast<thrust::complex<FPTYPE>*>(tmhpsi), // array of data
4562
reinterpret_cast<const thrust::complex<FPTYPE>*>(tmpsi_in)); // array of data

source/module_hamilt_pw/hamilt_pwdft/kernels/ekinetic_op.cpp

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,50 @@ struct ekinetic_pw_op<FPTYPE, base_device::DEVICE_CPU>
99
const int& nband,
1010
const int& npw,
1111
const int& max_npw,
12+
const bool& is_first_node,
1213
const FPTYPE& tpiba2,
1314
const FPTYPE* gk2_ik,
1415
std::complex<FPTYPE>* tmhpsi,
1516
const std::complex<FPTYPE>* tmpsi_in)
1617
{
18+
if (is_first_node)
19+
{
20+
for (int ib = 0; ib < nband; ++ib)
21+
{
1722
#ifdef _OPENMP
18-
#pragma omp parallel for collapse(2) schedule(static, 4096/sizeof(FPTYPE))
23+
#pragma omp parallel for
1924
#endif
20-
for (int ib = 0; ib < nband; ++ib) {
21-
for (int ig = 0; ig < npw; ++ig) {
22-
tmhpsi[ib * max_npw + ig] += gk2_ik[ig] * tpiba2 * tmpsi_in[ib * max_npw + ig];
23-
}
25+
for (int ig = 0; ig < npw; ++ig)
26+
{
27+
tmhpsi[ig] = gk2_ik[ig] * tpiba2 * tmpsi_in[ig];
28+
}
29+
#ifdef _OPENMP
30+
#pragma omp parallel for
31+
#endif
32+
for (int ig = npw; ig < max_npw; ++ig)
33+
{
34+
tmhpsi[ig] = 0.0;
35+
}
36+
tmpsi_in += max_npw;
37+
tmhpsi += max_npw;
38+
}
39+
}
40+
else
41+
{
42+
for (int ib = 0; ib < nband; ++ib)
43+
{
44+
#ifdef _OPENMP
45+
#pragma omp parallel for
46+
#endif
47+
for (int ig = 0; ig < npw; ++ig)
48+
{
49+
tmhpsi[ig] += gk2_ik[ig] * tpiba2 * tmpsi_in[ig];
50+
}
51+
tmpsi_in += max_npw;
52+
tmhpsi += max_npw;
53+
}
54+
}
2455
}
25-
}
2656
};
2757

2858
template struct ekinetic_pw_op<float, base_device::DEVICE_CPU>;

0 commit comments

Comments
 (0)