Skip to content

Commit 31cad58

Browse files
committed
optimize hPsi
1 parent f50c546 commit 31cad58

File tree

7 files changed

+31
-13
lines changed

7 files changed

+31
-13
lines changed

source/module_hamilt_general/operator.cpp

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -60,7 +60,7 @@ typename Operator<T, Device>::hpsi_info Operator<T, Device>::hPsi(hpsi_info& inp
6060
this->hpsi = new psi::Psi<T, Device>(hpsi_pointer, *psi_input, 1, nbands / psi_input->npol);
6161
}
6262

63-
auto call_act = [&, this](const Operator* op) -> void {
63+
auto call_act = [&, this](const Operator* op, const bool& is_first_node) -> void {
6464
// a "psi" with the bands of needed range
6565
psi::Psi<T, Device> psi_wrapper(const_cast<T*>(tmpsi_in), 1, nbands, psi_input->get_nbasis());
6666
switch (op->get_act_type())
@@ -69,17 +69,17 @@ typename Operator<T, Device>::hpsi_info Operator<T, Device>::hPsi(hpsi_info& inp
6969
op->act(psi_wrapper, *this->hpsi, nbands);
7070
break;
7171
default:
72-
op->act(nbands, psi_input->get_nbasis(), psi_input->npol, tmpsi_in, this->hpsi->get_pointer(), psi_input->get_ngk(op->ik));
72+
op->act(nbands, psi_input->get_nbasis(), psi_input->npol, tmpsi_in, this->hpsi->get_pointer(), psi_input->get_ngk(op->ik), is_first_node);
7373
break;
7474
}
7575
};
7676

7777
ModuleBase::timer::tick("Operator", "hPsi");
78-
call_act(this);
78+
call_act(this, true); // first node
7979
Operator* node((Operator*)this->next_op);
8080
while (node != nullptr)
8181
{
82-
call_act(node);
82+
call_act(node, false); // other nodes
8383
node = (Operator*)(node->next_op);
8484
}
8585
ModuleBase::timer::tick("Operator", "hPsi");
@@ -162,7 +162,7 @@ T* Operator<T, Device>::get_hpsi(const hpsi_info& info) const
162162
size_t total_hpsi_size = nbands_range * this->hpsi->get_nbasis();
163163
// ModuleBase::GlobalFunc::ZEROS(hpsi_pointer, total_hpsi_size);
164164
// denghui replaced at 20221104
165-
set_memory_op()(this->ctx, hpsi_pointer, 0, total_hpsi_size);
165+
// set_memory_op()(this->ctx, hpsi_pointer, 0, total_hpsi_size);
166166
return hpsi_pointer;
167167
}
168168

source/module_hamilt_pw/hamilt_pwdft/kernels/ekinetic_op.cpp

Lines changed: 12 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -17,28 +17,32 @@ struct ekinetic_pw_op<FPTYPE, base_device::DEVICE_CPU>
1717
{
1818
if (is_first_node)
1919
{
20-
#ifdef _OPENMP
21-
#pragma omp parallel for collapse(2) schedule(static, 4096 / sizeof(FPTYPE))
22-
#endif
2320
for (int ib = 0; ib < nband; ++ib)
2421
{
22+
#ifdef _OPENMP
23+
#pragma omp parallel for
24+
#endif
2525
for (int ig = 0; ig < npw; ++ig)
2626
{
27-
tmhpsi[ib * max_npw + ig] = gk2_ik[ig] * tpiba2 * tmpsi_in[ib * max_npw + ig];
27+
tmhpsi[ig] = gk2_ik[ig] * tpiba2 * tmpsi_in[ig];
2828
}
29+
tmpsi_in += max_npw;
30+
tmhpsi += max_npw;
2931
}
3032
}
3133
else
3234
{
33-
#ifdef _OPENMP
34-
#pragma omp parallel for collapse(2) schedule(static, 4096 / sizeof(FPTYPE))
35-
#endif
3635
for (int ib = 0; ib < nband; ++ib)
3736
{
37+
#ifdef _OPENMP
38+
#pragma omp parallel for
39+
#endif
3840
for (int ig = 0; ig < npw; ++ig)
3941
{
40-
tmhpsi[ib * max_npw + ig] += gk2_ik[ig] * tpiba2 * tmpsi_in[ib * max_npw + ig];
42+
tmhpsi[ig] += gk2_ik[ig] * tpiba2 * tmpsi_in[ig];
4143
}
44+
tmpsi_in += max_npw;
45+
tmhpsi += max_npw;
4246
}
4347
}
4448
}

source/module_hamilt_pw/hamilt_pwdft/operator_pw/meta_pw.cpp

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -52,6 +52,10 @@ void Meta<OperatorPW<T, Device>>::act(
5252
}
5353

5454
ModuleBase::timer::tick("Operator", "MetaPW");
55+
if(is_first_node)
56+
{
57+
setmem_complex_op()(this->ctx, tmhpsi, 0, nbasis*nbands);
58+
}
5559

5660
const int current_spin = this->isk[this->ik];
5761
int max_npw = nbasis / npol;

source/module_hamilt_pw/hamilt_pwdft/operator_pw/meta_pw.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -84,6 +84,7 @@ class Meta<OperatorPW<T, Device>> : public OperatorPW<T, Device>
8484
using vector_mul_vector_op = hsolver::vector_mul_vector_op<T, Device>;
8585
using resmem_complex_op = base_device::memory::resize_memory_op<T, Device>;
8686
using delmem_complex_op = base_device::memory::delete_memory_op<T, Device>;
87+
using setmem_complex_op = base_device::memory::set_memory_op<T, Device>;
8788
};
8889

8990
} // namespace hamilt

source/module_hamilt_pw/hamilt_pwdft/operator_pw/nonlocal_pw.cpp

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -218,6 +218,10 @@ void Nonlocal<OperatorPW<T, Device>>::act(
218218
const bool is_first_node)const
219219
{
220220
ModuleBase::timer::tick("Operator", "NonlocalPW");
221+
if(is_first_node)
222+
{
223+
setmem_complex_op()(this->ctx, tmhpsi, 0, nbasis*nbands);
224+
}
221225
if(!PARAM.inp.use_paw)
222226
{
223227
this->npw = ngk_ik;

source/module_hamilt_pw/hamilt_pwdft/operator_pw/veff_pw.cpp

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,10 @@ void Veff<OperatorPW<T, Device>>::act(
4545
const bool is_first_node)const
4646
{
4747
ModuleBase::timer::tick("Operator", "VeffPW");
48+
if(is_first_node)
49+
{
50+
setmem_complex_op()(this->ctx, tmhpsi, 0, nbasis*nbands);
51+
}
4852

4953
int max_npw = nbasis / npol;
5054
const int current_spin = this->isk[this->ik];

source/module_hamilt_pw/hamilt_pwdft/operator_pw/veff_pw.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -73,6 +73,7 @@ class Veff<OperatorPW<T, Device>> : public OperatorPW<T, Device>
7373

7474
using resmem_complex_op = base_device::memory::resize_memory_op<T, Device>;
7575
using delmem_complex_op = base_device::memory::delete_memory_op<T, Device>;
76+
using setmem_complex_op = base_device::memory::set_memory_op<T, Device>;
7677
};
7778

7879
} // namespace hamilt

0 commit comments

Comments
 (0)