Skip to content

Commit 54e445f

Browse files
EXX KPAR WORKS on NSPIN=2
1 parent 80fd973 commit 54e445f

File tree

4 files changed

+272
-175
lines changed

4 files changed

+272
-175
lines changed

source/source_pw/module_pwdft/operator_pw/exx_pw_ace.cpp

Lines changed: 137 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -116,150 +116,161 @@ void OperatorEXXPW<T, Device>::construct_ace() const
116116
ModuleBase::timer::tick("OperatorEXXPW", "construct_ace");
117117

118118
int nk_max = kv->para_k.get_max_nks_pool();
119-
for (int ik = 0; ik < nk_max; ik++)
119+
int nspin_fac = PARAM.inp.nspin == 2 ? 2 : 1;
120+
for (int ispin = 0; ispin < nspin_fac; ispin++)
120121
{
121-
int npwk = wfcpw->npwk[ik];
122+
for (int ik0 = 0; ik0 < nk_max; ik0++)
123+
{
124+
int ik = ik0 + ispin * nk_max;
125+
int npwk = wfcpw->npwk[ik];
122126

123-
T* Xi_ace = Xi_ace_k[ik];
124-
psi.fix_kb(ik, 0);
125-
T* p_psi = psi.get_pointer();
127+
T* Xi_ace = Xi_ace_k[ik];
128+
psi.fix_kb(ik, 0);
129+
T* p_psi = psi.get_pointer();
126130

127-
setmem_complex_op()(h_psi_ace, 0, nbands * nbasis);
131+
setmem_complex_op()(h_psi_ace, 0, nbands * nbasis);
128132

129-
setmem_complex_op()(h_psi_recip, 0, wfcpw->npwk_max);
130-
setmem_complex_op()(h_psi_real, 0, rhopw_dev->nrxx);
131-
setmem_complex_op()(density_real, 0, rhopw_dev->nrxx);
132-
setmem_complex_op()(density_recip, 0, rhopw_dev->npw);
133-
setmem_complex_op()(psi_nk_real, 0, wfcpw->nrxx);
134-
setmem_complex_op()(psi_mq_real, 0, wfcpw->nrxx);
135-
int nqs = kv->get_nkstot_full();
133+
setmem_complex_op()(h_psi_recip, 0, wfcpw->npwk_max);
134+
setmem_complex_op()(h_psi_real, 0, rhopw_dev->nrxx);
135+
setmem_complex_op()(density_real, 0, rhopw_dev->nrxx);
136+
setmem_complex_op()(density_recip, 0, rhopw_dev->npw);
137+
setmem_complex_op()(psi_nk_real, 0, wfcpw->nrxx);
138+
setmem_complex_op()(psi_mq_real, 0, wfcpw->nrxx);
139+
int nqs = kv->get_nkstot_full();
136140

137-
bool skip_ik = false;
138-
if (ik >= wfcpw->nks)
139-
{
140-
skip_ik = true;
141-
}
142-
if (skip_ik)
143-
{
144-
// ik fixed here, select band n
145-
for (int iq = 0; iq < nqs; iq++)
141+
bool skip_ik = false;
142+
if (ik >= wfcpw->nks)
143+
{
144+
skip_ik = true;
145+
}
146+
if (skip_ik)
146147
{
147-
// for \psi_nk, get the pw of iq and band m
148-
get_exx_potential<Real, Device>(kv, wfcpw, rhopw_dev, pot, tpiba, gamma_extrapolation, ucell->omega, ik, iq);
148+
// ik fixed here, select band n
149+
for (int iq0 = 0; iq0 < nqs; iq0++)
150+
{
151+
int iq = iq0 + ik;
152+
// for \psi_nk, get the pw of iq and band m
153+
get_exx_potential<Real, Device>(kv, wfcpw, rhopw_dev, pot, tpiba, gamma_extrapolation, ucell->omega, ik, iq);
149154

150-
// decide which pool does the iq belong to
151-
int iq_pool = kv->para_k.whichpool[iq];
152-
int iq_loc = iq - kv->para_k.startk_pool[iq_pool];
155+
// decide which pool does the iq belong to
156+
int iq_pool = kv->para_k.whichpool[iq0];
157+
int iq_loc = iq - kv->para_k.startk_pool[iq_pool];
153158

154-
for (int m_iband = 0; m_iband < psi.get_nbands(); m_iband++)
155-
{
156-
double wg_mqb = 0;
157-
bool skip = false;
158-
if (iq_pool == GlobalV::MY_POOL)
159+
for (int m_iband = 0; m_iband < psi.get_nbands(); m_iband++)
159160
{
160-
wg_mqb = (*wg)(iq_loc, m_iband);
161-
}
161+
double wg_mqb = 0;
162+
bool skip = false;
163+
if (iq_pool == GlobalV::MY_POOL)
164+
{
165+
wg_mqb = (*wg)(iq_loc, m_iband);
166+
}
162167

163-
MPI_Bcast(&wg_mqb, 1, MPI_DOUBLE, kv->para_k.get_startpro_pool(iq_pool), MPI_COMM_WORLD);
168+
MPI_Bcast(&wg_mqb, 1, MPI_DOUBLE, kv->para_k.get_startpro_pool(iq_pool), MPI_COMM_WORLD);
164169

165-
if (wg_mqb < 1e-12)
166-
continue;
170+
if (wg_mqb < 1e-12)
171+
continue;
167172

168-
if (iq_pool == GlobalV::MY_POOL)
169-
{
170-
const T* psi_mq = get_pw(m_iband, iq_loc);
171-
wfcpw->recip_to_real(ctx, psi_mq, psi_mq_real, iq_loc);
172-
// send
173-
}
174-
// if (iq == 0)
175-
// std::cout << "Bcast psi_mq_real" << std::endl;
176-
MPI_Bcast(psi_mq_real, wfcpw->nrxx, MPI_DOUBLE_COMPLEX, iq_pool, KP_WORLD);
173+
if (iq_pool == GlobalV::MY_POOL)
174+
{
175+
const T* psi_mq = get_pw(m_iband, iq_loc);
176+
wfcpw->recip_to_real(ctx, psi_mq, psi_mq_real, iq_loc);
177+
// send
178+
}
179+
// if (iq == 0)
180+
// std::cout << "Bcast psi_mq_real" << std::endl;
181+
MPI_Bcast(psi_mq_real, wfcpw->nrxx, MPI_DOUBLE_COMPLEX, iq_pool, KP_WORLD);
177182

178183

179-
} // end of iq
184+
} // end of iq
180185

186+
}
181187
}
182-
}
183-
else
184-
{
185-
*ik_ = ik;
186-
act_op(nbands, nbasis, 1, p_psi, h_psi_ace, nbasis, false);
187-
// psi_h_psi_ace = psi^\dagger * h_psi_ace
188-
// p_exx_helper->psi.fix_kb(0, 0);
189-
gemm_complex_op()('C',
190-
'N',
191-
nbands,
192-
nbands,
193-
npwk,
194-
&intermediate_one,
195-
p_psi,
196-
nbasis,
197-
h_psi_ace,
198-
nbasis,
199-
&intermediate_zero,
200-
psi_h_psi_ace,
201-
nbands);
202-
203-
// reduction of psi_h_psi_ace, due to distributed memory
204-
Parallel_Reduce::reduce_pool(psi_h_psi_ace, nbands * nbands);
205-
206-
T intermediate_minus_one = -1.0;
207-
axpy_complex_op()(nbands * nbands,
208-
&intermediate_minus_one,
209-
psi_h_psi_ace,
210-
1,
211-
L_ace,
212-
1);
213-
214-
215-
int info = 0;
216-
char up = 'U', lo = 'L';
217-
//
218-
// for (int i = 0; i < nbands; ++i)
219-
// {
220-
// for (int j = 0; j < nbands; ++j)
221-
// {
222-
// {
223-
// std::cout << psi_h_psi_ace[i * nbands + j] << " ";
224-
// }
225-
// }
226-
// std::cout << std::endl;
227-
// }
228-
// MPI_Barrier(MPI_COMM_WORLD);
229-
// MPI_Abort(MPI_COMM_WORLD, 0);
230-
231-
lapack_potrf()(lo, nbands, L_ace, nbands);
232-
233-
// expand for-loop
234-
for (int i = 0; i < nbands; ++i) {
235-
setmem_complex_op()(L_ace + i * nbands, 0, i);
188+
else
189+
{
190+
*ik_ = ik;
191+
act_op_kpar(nbands, nbasis, 1, p_psi, h_psi_ace, nbasis, false);
192+
// psi_h_psi_ace = psi^\dagger * h_psi_ace
193+
// p_exx_helper->psi.fix_kb(0, 0);
194+
gemm_complex_op()('C',
195+
'N',
196+
nbands,
197+
nbands,
198+
npwk,
199+
&intermediate_one,
200+
p_psi,
201+
nbasis,
202+
h_psi_ace,
203+
nbasis,
204+
&intermediate_zero,
205+
psi_h_psi_ace,
206+
nbands);
207+
208+
// reduction of psi_h_psi_ace, due to distributed memory
209+
Parallel_Reduce::reduce_pool(psi_h_psi_ace, nbands * nbands);
210+
211+
T intermediate_minus_one = -1.0;
212+
axpy_complex_op()(nbands * nbands,
213+
&intermediate_minus_one,
214+
psi_h_psi_ace,
215+
1,
216+
L_ace,
217+
1);
218+
219+
220+
int info = 0;
221+
char up = 'U', lo = 'L';
222+
223+
// for (int i = 0; i < nbands; ++i)
224+
// {
225+
// for (int j = 0; j < nbands; ++j)
226+
// {
227+
// // std::cout << L_ace[i * nbands + j]. << " ";
228+
// if (L_ace[i * nbands + j].imag() >= 0.0)
229+
// {
230+
// std::cout << L_ace[i * nbands + j].real() << "+" << L_ace[i * nbands + j].imag() << "im ";
231+
// }
232+
// else
233+
// {
234+
// std::cout << L_ace[i * nbands + j].real() << L_ace[i * nbands + j].imag() << "im ";
235+
// }
236+
// }
237+
// std::cout << ";" << std::endl;
238+
// }
239+
// MPI_Barrier(MPI_COMM_WORLD);
240+
// MPI_Abort(MPI_COMM_WORLD, 0);
241+
242+
lapack_potrf()(lo, nbands, L_ace, nbands);
243+
244+
// expand for-loop
245+
for (int i = 0; i < nbands; ++i) {
246+
setmem_complex_op()(L_ace + i * nbands, 0, i);
247+
}
248+
249+
// L_ace inv in place
250+
char non = 'N';
251+
lapack_trtri()(lo, non, nbands, L_ace, nbands);
252+
253+
// Xi_ace = L_ace^-1 * h_psi_ace^dagger
254+
gemm_complex_op()('N',
255+
'C',
256+
nbands,
257+
npwk,
258+
nbands,
259+
&intermediate_one,
260+
L_ace,
261+
nbands,
262+
h_psi_ace,
263+
nbasis,
264+
&intermediate_zero,
265+
Xi_ace,
266+
nbands);
267+
268+
// clear mem
269+
setmem_complex_op()(h_psi_ace, 0, nbands * nbasis);
270+
setmem_complex_op()(psi_h_psi_ace, 0, nbands * nbands);
271+
setmem_complex_op()(L_ace, 0, nbands * nbands);
236272
}
237-
238-
// L_ace inv in place
239-
char non = 'N';
240-
lapack_trtri()(lo, non, nbands, L_ace, nbands);
241-
242-
// Xi_ace = L_ace^-1 * h_psi_ace^dagger
243-
gemm_complex_op()('N',
244-
'C',
245-
nbands,
246-
npwk,
247-
nbands,
248-
&intermediate_one,
249-
L_ace,
250-
nbands,
251-
h_psi_ace,
252-
nbasis,
253-
&intermediate_zero,
254-
Xi_ace,
255-
nbands);
256-
257-
// clear mem
258-
setmem_complex_op()(h_psi_ace, 0, nbands * nbasis);
259-
setmem_complex_op()(psi_h_psi_ace, 0, nbands * nbands);
260-
setmem_complex_op()(L_ace, 0, nbands * nbands);
261273
}
262-
263274
}
264275

265276
*ik_ = ik_save;

source/source_pw/module_pwdft/operator_pw/exx_pw_pot.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,7 @@ double exx_divergence(Conv_Coulomb_Pot_K::Coulomb_Type coulomb_type,
404404

405405
// this is the \sum_q F(q) part
406406
// temporarily for all k points, should be replaced to q points later
407-
for (int ik = 0; ik < wfcpw->nks; ik++)
407+
for (int ik = 0; ik < wfcpw->nks / nk_fac; ik++)
408408
{
409409
const ModuleBase::Vector3<double> k_c = wfcpw->kvec_c[ik];
410410
const ModuleBase::Vector3<double> k_d = wfcpw->kvec_d[ik];
@@ -497,9 +497,9 @@ double exx_divergence(Conv_Coulomb_Pot_K::Coulomb_Type coulomb_type,
497497
aa += 1.0 / std::sqrt(alpha * ModuleBase::PI);
498498

499499
div -= ModuleBase::e2 * ucell_omega * aa;
500-
exx_div = div * kv->get_nkstot_full() / nk_fac;
500+
exx_div = div * kv->get_nkstot_full();
501501
// exx_div = 0;
502-
// std::cout << "EXX divergence: " << exx_div << std::endl;
502+
// std::cout << "EXX divergence: " << exx_div << std::endl;
503503

504504
return exx_div;
505505
}

0 commit comments

Comments
 (0)