Skip to content

Commit a30d765

Browse files
committed
refactor pw
1 parent d3fb83f commit a30d765

File tree

8 files changed

+27
-39
lines changed

8 files changed

+27
-39
lines changed

source/module_basis/module_pw/pw_basis_k.cpp

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ PW_Basis_K::~PW_Basis_K()
2222
delete[] igl2isz_k;
2323
delete[] igl2ig_k;
2424
delete[] gk2;
25-
delete[] ig2ixyz_k_;
2625
#if defined(__CUDA) || defined(__ROCM)
2726
if (this->device == "gpu") {
2827
if (this->precision == "single") {
@@ -169,6 +168,7 @@ void PW_Basis_K::setupIndGk()
169168
syncmem_int_h2d_op()(gpu_ctx, cpu_ctx, this->d_igl2isz_k, this->igl2isz_k, this->npwk_max * this->nks);
170169
}
171170
#endif
171+
this->get_ig2ixyz_k();
172172
return;
173173
}
174174

@@ -334,8 +334,12 @@ int& PW_Basis_K::getigl2ig(const int ik, const int igl) const
334334

335335
void PW_Basis_K::get_ig2ixyz_k()
336336
{
337-
delete[] this->ig2ixyz_k_;
338-
this->ig2ixyz_k_ = new int [this->npwk_max * this->nks];
337+
if (this->device != "gpu")
338+
{
339+
//only GPU need to get ig2ixyz_k
340+
return;
341+
}
342+
int * ig2ixyz_k_cpu = new int [this->npwk_max * this->nks];
339343
ModuleBase::Memory::record("PW_B_K::ig2ixyz", sizeof(int) * this->npwk_max * this->nks);
340344
assert(gamma_only == false); //We only finish non-gamma_only fft on GPU temperarily.
341345
for(int ik = 0; ik < this->nks; ++ik)
@@ -348,15 +352,12 @@ void PW_Basis_K::get_ig2ixyz_k()
348352
int ixy = this->is2fftixy[is];
349353
int iy = ixy % this->ny;
350354
int ix = ixy / this->ny;
351-
ig2ixyz_k_[igl + ik * npwk_max] = iz + iy * nz + ix * ny * nz;
355+
ig2ixyz_k_cpu[igl + ik * npwk_max] = iz + iy * nz + ix * ny * nz;
352356
}
353357
}
354-
#if defined(__CUDA) || defined(__ROCM)
355-
if (this->device == "gpu") {
356-
resmem_int_op()(gpu_ctx, ig2ixyz_k, this->npwk_max * this->nks);
357-
syncmem_int_h2d_op()(gpu_ctx, cpu_ctx, this->ig2ixyz_k, this->ig2ixyz_k_, this->npwk_max * this->nks);
358-
}
359-
#endif
358+
resmem_int_op()(gpu_ctx, ig2ixyz_k, this->npwk_max * this->nks);
359+
syncmem_int_h2d_op()(gpu_ctx, cpu_ctx, this->ig2ixyz_k, ig2ixyz_k_cpu, this->npwk_max * this->nks);
360+
delete[] ig2ixyz_k_cpu;
360361
}
361362

362363
std::vector<int> PW_Basis_K::get_ig2ix(const int ik) const

source/module_basis/module_pw/pw_basis_k.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,6 @@ class PW_Basis_K : public PW_Basis
7171
const bool xprime_in = true
7272
);
7373

74-
void get_ig2ixyz_k();
75-
7674
public:
7775
int nks=0;//number of k points in this pool
7876
ModuleBase::Vector3<double> *kvec_d=nullptr; // Direct coordinates of k points
@@ -88,8 +86,7 @@ class PW_Basis_K : public PW_Basis
8886

8987
int *igl2isz_k=nullptr, * d_igl2isz_k = nullptr; //[npwk_max*nks] map (igl,ik) to (is,iz)
9088
int *igl2ig_k=nullptr;//[npwk_max*nks] map (igl,ik) to ig
91-
int *ig2ixyz_k=nullptr;
92-
int *ig2ixyz_k_=nullptr;
89+
int *ig2ixyz_k=nullptr; ///< [npw] map ig to ixyz
9390

9491
double *gk2=nullptr; // modulus (G+K)^2 of G vectors [npwk_max*nks]
9592

@@ -108,6 +105,8 @@ class PW_Basis_K : public PW_Basis
108105
double * d_gk2 = nullptr; // modulus (G+K)^2 of G vectors [npwk_max*nks]
109106
//create igl2isz_k map array for fft
110107
void setupIndGk();
108+
// get ig2ixyz_k
109+
void get_ig2ixyz_k();
111110
//calculate G+K, it is a private function
112111
ModuleBase::Vector3<double> cal_GplusK_cartesian(const int ik, const int ig) const;
113112

source/module_basis/module_pw/test/test4-4.cpp

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -213,13 +213,6 @@ TEST_F(PWTEST,test4_4)
213213
}
214214
}
215215

216-
//check getig2ixyz_k
217-
pwtest.get_ig2ixyz_k();
218-
for(int igl = 0; igl < npwk ; ++igl)
219-
{
220-
EXPECT_GE(pwtest.ig2ixyz_k_[igl + ik * pwtest.npwk_max], 0);
221-
}
222-
223216
}
224217
delete []tmp;
225218
delete [] rhor;

source/module_esolver/esolver_ks_pw.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -138,12 +138,6 @@ void ESolver_KS_PW<T, Device>::before_all_runners(UnitCell& ucell, const Input_p
138138
{
139139
// 1) call before_all_runners() of ESolver_KS
140140
ESolver_KS<T, Device>::before_all_runners(ucell, inp);
141-
#if defined(__CUDA) || defined(__ROCM)
142-
if (PARAM.inp.device == "gpu")
143-
{
144-
this->pw_wfc->get_ig2ixyz_k();
145-
}
146-
#endif
147141

148142
// 3) initialize ElecState,
149143
if (this->pelec == nullptr)
Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
etotref -3937.5288918827272937
2-
etotperatomref -1968.7644459414
3-
totalforceref 21308.015678
4-
totalstressref 8138673.223144
1+
etotref -4869.74705201
2+
etotperatomref -2434.87352600
3+
totalforceref 5.19483000
4+
totalstressref 37241.44843500
55
pointgroupref C_1
66
spacegroupref C_1
77
nksibzref 8
8-
totaltimeref 10.10
8+
totaltimeref 10.37
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
etotref -378.4158765482854960
1+
etotref -378.4158765482866329
22
etotperatomref -126.1386255161
3-
totalforceref 1005.225100
4-
totalstressref 2123.397615
3+
totalforceref 1005.225123
4+
totalstressref 2123.397548
55
pointgroupref C_2v
66
spacegroupref C_2v
77
nksibzref 1
8-
totaltimeref 0.56
8+
totaltimeref 1.05

tests/integrate/121_PW_kspacing/INPUT

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,4 @@ smearing_sigma 0.020000
1717

1818
kspacing 0.6
1919
pseudo_dir ../../PP_ORB
20+
pw_seed 1
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
etotref -31.02022530858069
2-
etotperatomref -15.5101126543
1+
etotref -31.02022535682493
2+
etotperatomref -15.5101126784
33
pointgroupref D_2h
44
spacegroupref D_2h
55
nksibzref 4
6-
totaltimeref 1.53
6+
totaltimeref 3.04

0 commit comments

Comments
 (0)