Skip to content

Commit 449fecc

Browse files
committed
optimize for
1 parent 033ce33 commit 449fecc

File tree

3 files changed

+11
-8
lines changed

3 files changed

+11
-8
lines changed

source/module_hamilt_pw/hamilt_pwdft/kernels/cuda/ekinetic_op.cu

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -21,22 +21,23 @@ __global__ void ekinetic_pw(
2121
{
2222
const int block_idx = blockIdx.x;
2323
const int thread_idx = threadIdx.x;
24+
const int start_idx = block_idx * max_npw;
2425
if(is_first_node)
2526
{
2627
for (int ii = thread_idx; ii < npw; ii += blockDim.x)
2728
{
28-
hpsi[block_idx * max_npw + ii] = gk2[ii] * tpiba2 * psi[block_idx * max_npw + ii];
29+
hpsi[start_idx + ii] = gk2[ii] * tpiba2 * psi[start_idx + ii];
2930
}
3031
for (int ii = npw + thread_idx; ii < max_npw; ii += blockDim.x)
3132
{
32-
hpsi[block_idx * max_npw + ii] = 0.0;
33+
hpsi[start_idx + ii] = 0.0;
3334
}
3435
}
3536
else
3637
{
3738
for (int ii = thread_idx; ii < npw; ii += blockDim.x)
3839
{
39-
hpsi[block_idx * max_npw + ii] += gk2[ii] * tpiba2 * psi[block_idx * max_npw + ii];
40+
hpsi[start_idx + ii] += gk2[ii] * tpiba2 * psi[start_idx + ii];
4041
}
4142
}
4243
}

source/module_hamilt_pw/hamilt_pwdft/kernels/rocm/ekinetic_op.hip.cu

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -21,22 +21,23 @@ __global__ void ekinetic_pw(
2121
{
2222
const int block_idx = blockIdx.x;
2323
const int thread_idx = threadIdx.x;
24+
const int start_idx = block_idx * max_npw;
2425
if(is_first_node)
2526
{
2627
for (int ii = thread_idx; ii < npw; ii += blockDim.x)
2728
{
28-
hpsi[block_idx * max_npw + ii] = gk2[ii] * tpiba2 * psi[block_idx * max_npw + ii];
29+
hpsi[start_idx + ii] = gk2[ii] * tpiba2 * psi[start_idx + ii];
2930
}
3031
for (int ii = npw + thread_idx; ii < max_npw; ii += blockDim.x)
3132
{
32-
hpsi[block_idx * max_npw + ii] = 0.0;
33+
hpsi[start_idx + ii] = 0.0;
3334
}
3435
}
3536
else
3637
{
3738
for (int ii = thread_idx; ii < npw; ii += blockDim.x)
3839
{
39-
hpsi[block_idx * max_npw + ii] += gk2[ii] * tpiba2 * psi[block_idx * max_npw + ii];
40+
hpsi[start_idx + ii] += gk2[ii] * tpiba2 * psi[start_idx + ii];
4041
}
4142
}
4243
}

source/module_hamilt_pw/hamilt_stodft/hamilt_sdft_pw.cpp

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -59,9 +59,10 @@ void HamiltSdftPW<T, Device>::hPsi_norm(const T* psi_in, T* hpsi_norm, const int
5959
{
6060
for (int ig = 0; ig < npwk; ++ig)
6161
{
62-
hpsi_norm[ib * npwk_max + ig]
63-
= (hpsi_norm[ib * npwk_max + ig] - Ebar * psi_in[ib * npwk_max + ig]) / DeltaE;
62+
hpsi_norm[ig] = (hpsi_norm[ig] - Ebar * psi_in[ig]) / DeltaE;
6463
}
64+
hpsi_norm += npwk_max;
65+
psi_in += npwk_max;
6566
}
6667
ModuleBase::timer::tick("HamiltSdftPW", "hPsi_norm");
6768
}

0 commit comments

Comments
 (0)