Skip to content

Commit f50c546

Browse files
committed
add is_first_node parameter for act function
1 parent 3c874a1 commit f50c546

File tree

15 files changed

+87
-30
lines changed

15 files changed

+87
-30
lines changed

source/module_hamilt_general/operator.h

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -62,7 +62,8 @@ class Operator
6262
const int npol,
6363
const T* tmpsi_in,
6464
T* tmhpsi,
65-
const int ngk_ik = 0)const {};
65+
const int ngk_ik = 0,
66+
const bool is_first_node = false)const {};
6667

6768
/// developer-friendly interfaces for act() function
6869
/// interface type 2: input and change the Psi-type HPsi

source/module_hamilt_pw/hamilt_pwdft/kernels/cuda/ekinetic_op.cu

Lines changed: 16 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -13,16 +13,27 @@ template <typename FPTYPE>
1313
__global__ void ekinetic_pw(
1414
const int npw,
1515
const int max_npw,
16+
const bool is_first_node,
1617
const FPTYPE tpiba2,
1718
const FPTYPE* gk2,
1819
thrust::complex<FPTYPE>* hpsi,
1920
const thrust::complex<FPTYPE>* psi)
2021
{
2122
const int block_idx = blockIdx.x;
2223
const int thread_idx = threadIdx.x;
23-
for (int ii = thread_idx; ii < npw; ii+= blockDim.x) {
24-
hpsi[block_idx * max_npw + ii]
25-
+= gk2[ii] * tpiba2 * psi[block_idx * max_npw + ii];
24+
if(is_first_node)
25+
{
26+
for (int ii = thread_idx; ii < npw; ii += blockDim.x)
27+
{
28+
hpsi[block_idx * max_npw + ii] = gk2[ii] * tpiba2 * psi[block_idx * max_npw + ii];
29+
}
30+
}
31+
else
32+
{
33+
for (int ii = thread_idx; ii < npw; ii += blockDim.x)
34+
{
35+
hpsi[block_idx * max_npw + ii] += gk2[ii] * tpiba2 * psi[block_idx * max_npw + ii];
36+
}
2637
}
2738
}
2839

@@ -31,6 +42,7 @@ void hamilt::ekinetic_pw_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const b
3142
const int& nband,
3243
const int& npw,
3344
const int& max_npw,
45+
const bool& is_first_node,
3446
const FPTYPE& tpiba2,
3547
const FPTYPE* gk2_ik,
3648
std::complex<FPTYPE>* tmhpsi,
@@ -39,7 +51,7 @@ void hamilt::ekinetic_pw_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const b
3951
// denghui implement 20221019
4052
// <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
4153
ekinetic_pw<FPTYPE><<<nband, THREADS_PER_BLOCK>>>(
42-
npw, max_npw, tpiba2, // control params
54+
npw, max_npw, is_first_node, tpiba2, // control params
4355
gk2_ik, // array of data
4456
reinterpret_cast<thrust::complex<FPTYPE>*>(tmhpsi), // array of data
4557
reinterpret_cast<const thrust::complex<FPTYPE>*>(tmpsi_in)); // array of data

source/module_hamilt_pw/hamilt_pwdft/kernels/ekinetic_op.cpp

Lines changed: 25 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -9,20 +9,39 @@ struct ekinetic_pw_op<FPTYPE, base_device::DEVICE_CPU>
99
const int& nband,
1010
const int& npw,
1111
const int& max_npw,
12+
const bool& is_first_node,
1213
const FPTYPE& tpiba2,
1314
const FPTYPE* gk2_ik,
1415
std::complex<FPTYPE>* tmhpsi,
1516
const std::complex<FPTYPE>* tmpsi_in)
1617
{
18+
if (is_first_node)
19+
{
1720
#ifdef _OPENMP
18-
#pragma omp parallel for collapse(2) schedule(static, 4096/sizeof(FPTYPE))
21+
#pragma omp parallel for collapse(2) schedule(static, 4096 / sizeof(FPTYPE))
1922
#endif
20-
for (int ib = 0; ib < nband; ++ib) {
21-
for (int ig = 0; ig < npw; ++ig) {
22-
tmhpsi[ib * max_npw + ig] += gk2_ik[ig] * tpiba2 * tmpsi_in[ib * max_npw + ig];
23-
}
23+
for (int ib = 0; ib < nband; ++ib)
24+
{
25+
for (int ig = 0; ig < npw; ++ig)
26+
{
27+
tmhpsi[ib * max_npw + ig] = gk2_ik[ig] * tpiba2 * tmpsi_in[ib * max_npw + ig];
28+
}
29+
}
30+
}
31+
else
32+
{
33+
#ifdef _OPENMP
34+
#pragma omp parallel for collapse(2) schedule(static, 4096 / sizeof(FPTYPE))
35+
#endif
36+
for (int ib = 0; ib < nband; ++ib)
37+
{
38+
for (int ig = 0; ig < npw; ++ig)
39+
{
40+
tmhpsi[ib * max_npw + ig] += gk2_ik[ig] * tpiba2 * tmpsi_in[ib * max_npw + ig];
41+
}
42+
}
43+
}
2444
}
25-
}
2645
};
2746

2847
template struct ekinetic_pw_op<float, base_device::DEVICE_CPU>;

source/module_hamilt_pw/hamilt_pwdft/kernels/ekinetic_op.h

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,7 @@ struct ekinetic_pw_op {
2626
const int& nband,
2727
const int& npw,
2828
const int& max_npw,
29+
const bool& is_first_node,
2930
const FPTYPE& tpiba2,
3031
const FPTYPE* gk2_ik,
3132
std::complex<FPTYPE>* tmhpsi,
@@ -41,6 +42,7 @@ struct ekinetic_pw_op<FPTYPE, base_device::DEVICE_GPU>
4142
const int& nband,
4243
const int& npw,
4344
const int& max_npw,
45+
const bool& is_first_node,
4446
const FPTYPE& tpiba2,
4547
const FPTYPE* gk2_ik,
4648
std::complex<FPTYPE>* tmhpsi,

source/module_hamilt_pw/hamilt_pwdft/kernels/rocm/ekinetic_op.hip.cu

Lines changed: 17 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -13,24 +13,37 @@ template <typename FPTYPE>
1313
__global__ void ekinetic_pw(
1414
const int npw,
1515
const int max_npw,
16+
const bool is_first_node,
1617
const FPTYPE tpiba2,
1718
const FPTYPE* gk2,
1819
thrust::complex<FPTYPE>* hpsi,
1920
const thrust::complex<FPTYPE>* psi)
2021
{
2122
const int block_idx = blockIdx.x;
2223
const int thread_idx = threadIdx.x;
23-
for (int ii = thread_idx; ii < npw; ii+= blockDim.x) {
24-
hpsi[block_idx * max_npw + ii]
25-
+= gk2[ii] * tpiba2 * psi[block_idx * max_npw + ii];
24+
if(is_first_node)
25+
{
26+
for (int ii = thread_idx; ii < npw; ii += blockDim.x)
27+
{
28+
hpsi[block_idx * max_npw + ii] = gk2[ii] * tpiba2 * psi[block_idx * max_npw + ii];
29+
}
2630
}
31+
else
32+
{
33+
for (int ii = thread_idx; ii < npw; ii += blockDim.x)
34+
{
35+
hpsi[block_idx * max_npw + ii] += gk2[ii] * tpiba2 * psi[block_idx * max_npw + ii];
36+
}
37+
}
38+
2739
}
2840

2941
template <typename FPTYPE>
3042
void hamilt::ekinetic_pw_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* dev,
3143
const int& nband,
3244
const int& npw,
3345
const int& max_npw,
46+
const bool& is_first_node,
3447
const FPTYPE& tpiba2,
3548
const FPTYPE* gk2_ik,
3649
std::complex<FPTYPE>* tmhpsi,
@@ -39,7 +52,7 @@ void hamilt::ekinetic_pw_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const b
3952
// denghui implement 20221019
4053
// <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
4154
hipLaunchKernelGGL(HIP_KERNEL_NAME(ekinetic_pw<FPTYPE>), dim3(nband), dim3(THREADS_PER_BLOCK), 0, 0,
42-
npw, max_npw, tpiba2, // control params
55+
npw, max_npw, is_first_node, tpiba2, // control params
4356
gk2_ik, // array of data
4457
reinterpret_cast<thrust::complex<FPTYPE>*>(tmhpsi), // array of data
4558
reinterpret_cast<const thrust::complex<FPTYPE>*>(tmpsi_in)); // array of data

source/module_hamilt_pw/hamilt_pwdft/operator_pw/ekinetic_pw.cpp

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -37,14 +37,15 @@ void Ekinetic<OperatorPW<T, Device>>::act(
3737
const int npol,
3838
const T* tmpsi_in,
3939
T* tmhpsi,
40-
const int ngk_ik)const
40+
const int ngk_ik,
41+
const bool is_first_node)const
4142
{
4243
ModuleBase::timer::tick("Operator", "EkineticPW");
4344
int max_npw = nbasis / npol;
4445

4546
const Real *gk2_ik = &(this->gk2[this->ik * this->gk2_col]);
4647
// denghui added 20221019
47-
ekinetic_op()(this->ctx, nbands, ngk_ik, max_npw, tpiba2, gk2_ik, tmhpsi, tmpsi_in);
48+
ekinetic_op()(this->ctx, nbands, ngk_ik, max_npw, is_first_node, tpiba2, gk2_ik, tmhpsi, tmpsi_in);
4849
// for (int ib = 0; ib < nbands; ++ib)
4950
// {
5051
// for (int ig = 0; ig < ngk_ik; ++ig)

source/module_hamilt_pw/hamilt_pwdft/operator_pw/ekinetic_pw.h

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -42,7 +42,8 @@ class Ekinetic<OperatorPW<T, Device>> : public OperatorPW<T, Device>
4242
const int npol,
4343
const T* tmpsi_in,
4444
T* tmhpsi,
45-
const int ngk_ik = 0)const override;
45+
const int ngk_ik = 0,
46+
const bool is_first_node = false)const override;
4647

4748
// denghuilu added for copy construct at 20221105
4849
int get_gk2_row() const {return this->gk2_row;}

source/module_hamilt_pw/hamilt_pwdft/operator_pw/meta_pw.cpp

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,8 @@ void Meta<OperatorPW<T, Device>>::act(
4343
const int npol,
4444
const T* tmpsi_in,
4545
T* tmhpsi,
46-
const int ngk_ik)const
46+
const int ngk_ik,
47+
const bool is_first_node)const
4748
{
4849
if (XC_Functional::get_func_type() != 3)
4950
{

source/module_hamilt_pw/hamilt_pwdft/operator_pw/meta_pw.h

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,8 @@ class Meta<OperatorPW<T, Device>> : public OperatorPW<T, Device>
4343
const int npol,
4444
const T* tmpsi_in,
4545
T* tmhpsi,
46-
const int ngk = 0)const override;
46+
const int ngk_ik = 0,
47+
const bool is_first_node = false)const override;
4748

4849
// denghui added for copy constructor at 20221105
4950
Real get_tpiba() const

source/module_hamilt_pw/hamilt_pwdft/operator_pw/nonlocal_pw.cpp

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -214,7 +214,8 @@ void Nonlocal<OperatorPW<T, Device>>::act(
214214
const int npol,
215215
const T* tmpsi_in,
216216
T* tmhpsi,
217-
const int ngk_ik)const
217+
const int ngk_ik,
218+
const bool is_first_node)const
218219
{
219220
ModuleBase::timer::tick("Operator", "NonlocalPW");
220221
if(!PARAM.inp.use_paw)

0 commit comments

Comments
 (0)