Skip to content

Commit 6f89e08

Browse files
Refactor
1 parent f79e698 commit 6f89e08

File tree

5 files changed

+544
-525
lines changed

5 files changed

+544
-525
lines changed

source/source_cell/pseudo.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "pseudo.h"
22
#include "source_base/tool_title.h"
3+
#include <cstdint>
34

45
pseudo::pseudo()
56
{

source/source_pw/module_pwdft/operator_pw/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ list(APPEND operator_ks_pw_srcs
66
meta_pw.cpp
77
velocity_pw.cpp
88
onsite_proj_pw.cpp
9+
operator_pw/exx_pw_ace.cpp
10+
operator_pw/exx_pw_pot.cpp
911
)
1012

1113
# this library is included in module_pwdft now
Lines changed: 247 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,247 @@
1+
#include "op_exx_pw.h"
2+
3+
namespace hamilt
4+
{
5+
template <typename T, typename Device>
6+
void OperatorEXXPW<T, Device>::act_op_ace(const int nbands,
7+
const int nbasis,
8+
const int npol,
9+
const T *tmpsi_in,
10+
T *tmhpsi,
11+
const int ngk_ik,
12+
const bool is_first_node) const
13+
{
14+
ModuleBase::timer::tick("OperatorEXXPW", "act_op_ace");
15+
// std::cout << "act_op_ace" << std::endl;
16+
// hpsi += -Xi^\dagger * Xi * psi
17+
T* Xi_ace = Xi_ace_k[this->ik];
18+
int nbands_tot = psi.get_nbands();
19+
int nbasis_max = psi.get_nbasis();
20+
// T* hpsi = nullptr;
21+
// resmem_complex_op()(hpsi, nbands_tot * nbasis);
22+
// setmem_complex_op()(hpsi, 0, nbands_tot * nbasis);
23+
T* Xi_psi = nullptr;
24+
resmem_complex_op()(Xi_psi, nbands_tot * nbands);
25+
setmem_complex_op()(Xi_psi, 0, nbands_tot * nbands);
26+
27+
char trans_N = 'N', trans_T = 'T', trans_C = 'C';
28+
T intermediate_one = 1.0, intermediate_zero = 0.0, intermediate_minus_one = -1.0;
29+
// Xi * psi
30+
gemm_complex_op()(trans_N,
31+
trans_N,
32+
nbands_tot,
33+
nbands,
34+
nbasis,
35+
&intermediate_one,
36+
Xi_ace,
37+
nbands_tot,
38+
tmpsi_in,
39+
nbasis,
40+
&intermediate_zero,
41+
Xi_psi,
42+
nbands_tot
43+
);
44+
45+
Parallel_Reduce::reduce_pool(Xi_psi, nbands_tot * nbands);
46+
47+
// Xi^\dagger * (Xi * psi)
48+
gemm_complex_op()(trans_C,
49+
trans_N,
50+
nbasis,
51+
nbands,
52+
nbands_tot,
53+
&intermediate_minus_one,
54+
Xi_ace,
55+
nbands_tot,
56+
Xi_psi,
57+
nbands_tot,
58+
&intermediate_one,
59+
tmhpsi,
60+
nbasis
61+
);
62+
63+
64+
// // negative sign, add to hpsi
65+
// vec_add_vec_complex_op()(this->ctx, nbands * nbasis, tmhpsi, hpsi, -1, tmhpsi, 1);
66+
// delmem_complex_op()(hpsi);
67+
delmem_complex_op()(Xi_psi);
68+
ModuleBase::timer::tick("OperatorEXXPW", "act_op_ace");
69+
70+
}
71+
72+
template <typename T, typename Device>
73+
void OperatorEXXPW<T, Device>::construct_ace() const
74+
{
75+
ModuleBase::timer::tick("OperatorEXXPW", "construct_ace");
76+
// int nkb = p_exx_helper->psi.get_nbands() * p_exx_helper->psi.get_nk();
77+
int nbands = psi.get_nbands();
78+
int nbasis = psi.get_nbasis();
79+
int nk = psi.get_nk();
80+
81+
int ik_save = this->ik;
82+
int * ik_ = const_cast<int*>(&this->ik);
83+
84+
T intermediate_one = 1.0, intermediate_zero = 0.0;
85+
86+
if (h_psi_ace == nullptr)
87+
{
88+
resmem_complex_op()(h_psi_ace, nbands * nbasis);
89+
setmem_complex_op()(h_psi_ace, 0, nbands * nbasis);
90+
}
91+
92+
if (Xi_ace_k.size() != nk)
93+
{
94+
Xi_ace_k.resize(nk);
95+
for (int i = 0; i < nk; i++)
96+
{
97+
resmem_complex_op()(Xi_ace_k[i], nbands * nbasis);
98+
}
99+
}
100+
101+
for (int i = 0; i < nk; i++)
102+
{
103+
setmem_complex_op()(Xi_ace_k[i], 0, nbands * nbasis);
104+
}
105+
106+
if (L_ace == nullptr)
107+
{
108+
resmem_complex_op()(L_ace, nbands * nbands);
109+
setmem_complex_op()(L_ace, 0, nbands * nbands);
110+
}
111+
112+
if (psi_h_psi_ace == nullptr)
113+
{
114+
resmem_complex_op()(psi_h_psi_ace, nbands * nbands);
115+
}
116+
117+
if (first_iter) return;
118+
119+
for (int ik = 0; ik < nk; ik++)
120+
{
121+
int npwk = wfcpw->npwk[ik];
122+
123+
T* Xi_ace = Xi_ace_k[ik];
124+
psi.fix_kb(ik, 0);
125+
T* p_psi = psi.get_pointer();
126+
127+
setmem_complex_op()(h_psi_ace, 0, nbands * nbasis);
128+
129+
*ik_ = ik;
130+
131+
act_op(
132+
nbands,
133+
nbasis,
134+
1,
135+
p_psi,
136+
h_psi_ace,
137+
nbasis,
138+
false
139+
);
140+
141+
// psi_h_psi_ace = psi^\dagger * h_psi_ace
142+
// p_exx_helper->psi.fix_kb(0, 0);
143+
gemm_complex_op()('C',
144+
'N',
145+
nbands,
146+
nbands,
147+
npwk,
148+
&intermediate_one,
149+
p_psi,
150+
nbasis,
151+
h_psi_ace,
152+
nbasis,
153+
&intermediate_zero,
154+
psi_h_psi_ace,
155+
nbands);
156+
157+
// reduction of psi_h_psi_ace, due to distributed memory
158+
Parallel_Reduce::reduce_pool(psi_h_psi_ace, nbands * nbands);
159+
160+
T intermediate_minus_one = -1.0;
161+
axpy_complex_op()(nbands * nbands,
162+
&intermediate_minus_one,
163+
psi_h_psi_ace,
164+
1,
165+
L_ace,
166+
1);
167+
168+
169+
int info = 0;
170+
char up = 'U', lo = 'L';
171+
172+
lapack_potrf()(lo, nbands, L_ace, nbands);
173+
174+
// expand for-loop
175+
for (int i = 0; i < nbands; ++i) {
176+
setmem_complex_op()(L_ace + i * nbands, 0, i);
177+
}
178+
179+
// L_ace inv in place
180+
char non = 'N';
181+
lapack_trtri()(lo, non, nbands, L_ace, nbands);
182+
183+
// Xi_ace = L_ace^-1 * h_psi_ace^dagger
184+
gemm_complex_op()('N',
185+
'C',
186+
nbands,
187+
npwk,
188+
nbands,
189+
&intermediate_one,
190+
L_ace,
191+
nbands,
192+
h_psi_ace,
193+
nbasis,
194+
&intermediate_zero,
195+
Xi_ace,
196+
nbands);
197+
198+
// clear mem
199+
setmem_complex_op()(h_psi_ace, 0, nbands * nbasis);
200+
setmem_complex_op()(psi_h_psi_ace, 0, nbands * nbands);
201+
setmem_complex_op()(L_ace, 0, nbands * nbands);
202+
203+
}
204+
205+
*ik_ = ik_save;
206+
ModuleBase::timer::tick("OperatorEXXPW", "construct_ace");
207+
208+
}
209+
210+
template <typename T, typename Device>
211+
double OperatorEXXPW<T, Device>::cal_exx_energy_ace(psi::Psi<T, Device>* ppsi_) const
212+
{
213+
double Eexx = 0;
214+
215+
psi::Psi<T, Device> psi_ = *ppsi_;
216+
int* ik_ = const_cast<int*>(&this->ik);
217+
int ik_save = this->ik;
218+
for (int i = 0; i < wfcpw->nks; i++)
219+
{
220+
setmem_complex_op()(h_psi_ace, 0, psi_.get_nbands() * psi_.get_nbasis());
221+
*ik_ = i;
222+
psi_.fix_kb(i, 0);
223+
T* psi_i = psi_.get_pointer();
224+
act_op_ace(psi_.get_nbands(), psi_.get_nbasis(), 1, psi_i, h_psi_ace, 0, true);
225+
226+
for (int nband = 0; nband < psi_.get_nbands(); nband++)
227+
{
228+
psi_.fix_kb(i, nband);
229+
T* psi_i_n = psi_.get_pointer();
230+
T* hpsi_i_n = h_psi_ace + nband * psi_.get_nbasis();
231+
double wg_i_n = (*wg)(i, nband);
232+
// Eexx += dot(psi_i_n, h_psi_i_n)
233+
Eexx += dot_op()(psi_.get_nbasis(), psi_i_n, hpsi_i_n, false) * wg_i_n * 2;
234+
}
235+
}
236+
237+
Parallel_Reduce::reduce_pool(Eexx);
238+
*ik_ = ik_save;
239+
return Eexx;
240+
}
241+
// Explicit instantiations of OperatorEXXPW for complex<float> and complex<double> on
// the CPU device; the GPU variants are compiled only when __CUDA or __ROCM is defined.
template class OperatorEXXPW<std::complex<float>, base_device::DEVICE_CPU>;
242+
template class OperatorEXXPW<std::complex<double>, base_device::DEVICE_CPU>;
243+
#if ((defined __CUDA) || (defined __ROCM))
244+
template class OperatorEXXPW<std::complex<float>, base_device::DEVICE_GPU>;
245+
template class OperatorEXXPW<std::complex<double>, base_device::DEVICE_GPU>;
246+
#endif
247+
}

0 commit comments

Comments
 (0)