Skip to content

Commit 1c6c47e

Browse files
Feature: ACE works. Next step is ACE energy.
1 parent 08d92dd commit 1c6c47e

File tree

2 files changed

+87
-46
lines changed

2 files changed

+87
-46
lines changed

source/module_esolver/module_exx_helper/exx_helper.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
template <typename T, typename Device>
44
double ModuleESolver::ESolver_KS_PW<T, Device>::Exx_Helper::cal_exx_energy(psi::Psi<T, Device>& psi, ESolver_KS_PW<T, Device>* this_)
55
{
6+
ModuleBase::timer::tick("ESolver_KS_PW", "cal_exx_energy");
7+
68
using setmem_complex_op = base_device::memory::set_memory_op<T, Device>;
79
using delmem_complex_op = base_device::memory::delete_memory_op<T, Device>;
810
T* psi_nk_real = new T[this_->pw_wfc->nrxx];
@@ -107,7 +109,6 @@ double ModuleESolver::ESolver_KS_PW<T, Device>::Exx_Helper::cal_exx_energy(psi::
107109
double exx_div = div;
108110

109111
if (wf_wg == nullptr) return 0.0;
110-
ModuleBase::timer::tick("OperatorEXXPW", "get_Eexx");
111112
// evaluate the Eexx
112113
// T Eexx_ik = 0.0;
113114
Real Eexx_ik_real = 0.0;
@@ -240,7 +241,7 @@ double ModuleESolver::ESolver_KS_PW<T, Device>::Exx_Helper::cal_exx_energy(psi::
240241
// std::cout << "omega = " << this_->pelec->omega << " tpiba = " << this_->pw_rho->tpiba2 << " exx_div = " << exx_div << std::endl;
241242

242243
Real Eexx = Eexx_ik_real;
243-
ModuleBase::timer::tick("OperatorEXXPW", "get_Eexx");
244+
ModuleBase::timer::tick("ESolver_KS_PW", "cal_exx_energy");
244245
return Eexx;
245246
}
246247

source/module_hamilt_pw/hamilt_pwdft/operator_pw/op_exx_pw.cpp

Lines changed: 84 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ void OperatorEXXPW<T, Device>::act(const int nbands,
118118
if (p_exx_helper->construct_ace)
119119
{
120120
construct_ace();
121+
std::cout << "ACE constructed" << std::endl;
121122
p_exx_helper->construct_ace = false;
122123
}
123124

@@ -153,7 +154,7 @@ void OperatorEXXPW<T, Device>::act_op(const int nbands,
153154

154155
// set_psi(&p_exx_helper->psi);
155156

156-
ModuleBase::timer::tick("OperatorEXXPW", "act");
157+
ModuleBase::timer::tick("OperatorEXXPW", "act_op");
157158

158159
setmem_complex_op()(h_psi_recip, 0, wfcpw->npwk_max);
159160
setmem_complex_op()(h_psi_real, 0, rhopw->nrxx);
@@ -257,7 +258,7 @@ void OperatorEXXPW<T, Device>::act_op(const int nbands,
257258

258259
}
259260

260-
ModuleBase::timer::tick("OperatorEXXPW", "act");
261+
ModuleBase::timer::tick("OperatorEXXPW", "act_op");
261262

262263
}
263264

@@ -270,6 +271,8 @@ void OperatorEXXPW<T, Device>::act_op_ace(const int nbands,
270271
const int ngk_ik,
271272
const bool is_first_node) const
272273
{
274+
ModuleBase::timer::tick("OperatorEXXPW", "act_op_ace");
275+
273276
// std::cout << "act_op_ace" << std::endl;
274277
// hpsi += -Xi^\dagger * Xi * psi
275278
auto Xi_ace = Xi_ace_k[this->ik];
@@ -325,6 +328,8 @@ void OperatorEXXPW<T, Device>::act_op_ace(const int nbands,
325328
// vec_add_vec_complex_op()(this->ctx, nbands * nbasis, tmhpsi, hpsi, -1, tmhpsi, 1);
326329
// delmem_complex_op()(hpsi);
327330
delmem_complex_op()(Xi_psi);
331+
ModuleBase::timer::tick("OperatorEXXPW", "act_op");
332+
328333
}
329334

330335
template <typename T, typename Device>
@@ -335,6 +340,9 @@ void OperatorEXXPW<T, Device>::construct_ace() const
335340
int nbasis = p_exx_helper->psi.get_nbasis();
336341
int nk = p_exx_helper->psi.get_nk();
337342

343+
int ik_store = this->ik;
344+
int *ik_ptr = const_cast<int*>(&this->ik);
345+
338346
T intermediate_one = 1.0, intermediate_zero = 0.0;
339347

340348
if (h_psi_ace == nullptr)
@@ -368,37 +376,53 @@ void OperatorEXXPW<T, Device>::construct_ace() const
368376
resmem_complex_op()(psi_h_psi_ace, nbands * nbands);
369377
}
370378

371-
// std::ofstream ofs_psi("psi.dat", std::ios::binary);
372-
// p_exx_helper->psi.fix_kb(0, 0);
373-
// ofs_psi.write(reinterpret_cast<char*>(p_exx_helper->psi.get_pointer()), nkb * nbasis * sizeof(T));
374-
// ofs_psi.close();
375-
376-
for (int ik = 0; ik < p_exx_helper->psi.get_nk(); ik++)
379+
for (int ik = 0; ik < nk; ik++)
377380
{
378-
auto Xi_ace = Xi_ace_k[ik];
381+
int npwk = wfcpw->npwk[ik];
379382

383+
T* Xi_ace = Xi_ace_k[ik];
380384
p_exx_helper->psi.fix_kb(ik, 0);
381385
T* p_psi = p_exx_helper->psi.get_pointer();
386+
387+
// if (ik == 1 && GlobalV::RANK_IN_POOL == 0)
388+
// {
389+
// std::ofstream ofs_psi("psi.dat", std::ios::binary);
390+
// // p_exx_helper->psi.fix_kb(0, 0);
391+
// ofs_psi.write(reinterpret_cast<char*>(p_psi), nbands * nbasis * sizeof(T));
392+
// ofs_psi.close();
393+
// }
394+
395+
setmem_complex_op()(h_psi_ace, 0, nbands * nbasis);
396+
397+
*ik_ptr = ik;
398+
382399
act_op(
383-
p_exx_helper->psi.get_nbands(),
400+
nbands,
384401
nbasis,
385402
1,
386403
p_psi,
387404
h_psi_ace,
388-
ik,
405+
nbasis,
389406
false
390407
);
391408

409+
// if (ik == 1 && GlobalV::RANK_IN_POOL == 0)
410+
// {
411+
// std::ofstream ofs_hpsi("hpsi.dat", std::ios::binary);
412+
// ofs_hpsi.write(reinterpret_cast<char*>(h_psi_ace), nbands * nbasis * sizeof(T));
413+
// ofs_hpsi.close();
414+
// }
415+
392416
// psi_h_psi_ace = psi^\dagger * h_psi_ace
393-
p_exx_helper->psi.fix_kb(0, 0);
417+
// p_exx_helper->psi.fix_kb(0, 0);
394418
gemm_complex_op()(this->ctx,
395419
'C',
396420
'N',
397421
nbands,
398422
nbands,
399-
nbasis,
423+
npwk,
400424
&intermediate_one,
401-
p_exx_helper->psi.get_pointer(),
425+
p_psi,
402426
nbasis,
403427
h_psi_ace,
404428
nbasis,
@@ -407,17 +431,19 @@ void OperatorEXXPW<T, Device>::construct_ace() const
407431
nbands);
408432

409433
// reduction of psi_h_psi_ace, due to distributed memory
410-
Parallel_Reduce::reduce_pool(psi_h_psi_ace, nbands * nbands);
411-
412-
// // save psi_h_psi_ace to disk
434+
Parallel_Reduce::reduce_pool(psi_h_psi_ace, nbands * nbands);
435+
436+
// if (ik == 1 && GlobalV::RANK_IN_POOL == 0)
437+
// { // save psi_h_psi_ace to disk
413438
// std::ofstream ofs_psi_hpsi("psihpsi.dat", std::ios::binary);
414-
// ofs_psi_hpsi.write(reinterpret_cast<char*>(psi_h_psi_ace), nkb * nkb * sizeof(T));
439+
// ofs_psi_hpsi.write(reinterpret_cast<char*>(psi_h_psi_ace), nbands * nbands * sizeof(T));
415440
// ofs_psi_hpsi.close();
441+
// }
416442

417-
// L_ace = cholesky(-psi_h_psi_ace)
418-
#ifdef _OPENMP
419-
#pragma omp parallel for schedule(static)
420-
#endif
443+
// L_ace = cholesky(-psi_h_psi_ace)
444+
#ifdef _OPENMP
445+
#pragma omp parallel for schedule(static)
446+
#endif
421447
for (int i = 0; i < nbands; i++)
422448
{
423449
for (int j = 0; j < nbands; j++)
@@ -438,26 +464,29 @@ void OperatorEXXPW<T, Device>::construct_ace() const
438464
zpotrf_(&lo, &nbands, L_ace, &nbands, &info);
439465
}
440466

441-
// expand for-loop
442-
#ifdef _OPENMP
443-
#pragma omp parallel for schedule(static)
444-
#endif
467+
// expand for-loop
468+
#ifdef _OPENMP
469+
#pragma omp parallel for schedule(static)
470+
#endif
445471
for (int i = 0; i < nbands; i++)
446472
{
447473
for (int j = 0; j < nbands; j++)
448474
{
449475
if (j < i)
450476
{
451-
// L_ace[j * nkb + i] = std::conj(L_ace[i * nkb + j]);
477+
// L_ace[j * nkb + i] = std::conj(L_ace[i * nkb + j]);
452478
L_ace[i * nbands + j] = 0.0;
453479
}
454480
}
455481
}
456482

457-
// // save L_ace to disk
458-
// std::ofstream ofs_L("L.dat", std::ios::binary);
459-
// ofs_L.write(reinterpret_cast<char*>(L_ace), nkb * nkb * sizeof(T));
460-
// ofs_L.close();
483+
// // save L_ace to disk
484+
// if (ik == 1 && GlobalV::RANK_IN_POOL == 0)
485+
// {
486+
// std::ofstream ofs_L("L.dat", std::ios::binary);
487+
// ofs_L.write(reinterpret_cast<char*>(L_ace), nbands * nbands * sizeof(T));
488+
// ofs_L.close();
489+
// }
461490

462491
// L_ace inv in place
463492
// T == std::complex<float> or std::complex<double>
@@ -474,17 +503,20 @@ void OperatorEXXPW<T, Device>::construct_ace() const
474503
ztrtri_(&lo, &non_unitary, &nbands, L_ace, &nbands, &info);
475504
}
476505

477-
// // save L_ace inv to disk
478-
// std::ofstream ofs_L_inv("L_inv.dat", std::ios::binary);
479-
// ofs_L_inv.write(reinterpret_cast<char*>(L_ace), nkb * nkb * sizeof(T));
480-
// ofs_L_inv.close();
506+
// // save L_ace inv to disk
507+
// if (ik == 1 && GlobalV::RANK_IN_POOL == 0)
508+
// {
509+
// std::ofstream ofs_L_inv("L_inv.dat", std::ios::binary);
510+
// ofs_L_inv.write(reinterpret_cast<char*>(L_ace), nbands * nbands * sizeof(T));
511+
// ofs_L_inv.close();
512+
// }
481513

482514
// Xi_ace = L_ace^-1 * h_psi_ace^dagger
483515
gemm_complex_op()(this->ctx,
484516
'N',
485517
'C',
486518
nbands,
487-
nbasis,
519+
npwk,
488520
nbands,
489521
&intermediate_one,
490522
L_ace,
@@ -495,23 +527,31 @@ void OperatorEXXPW<T, Device>::construct_ace() const
495527
Xi_ace,
496528
nbands);
497529

498-
// // save Xi_ace to disk
499-
// std::ofstream ofs_Xi("Xi.dat", std::ios::binary);
500-
// ofs_Xi.write(reinterpret_cast<char*>(Xi_ace), nkb * nbasis * sizeof(T));
501-
// ofs_Xi.close();
530+
// save Xi_ace to disk
531+
// if (ik == 1 && GlobalV::RANK_IN_POOL == 0)
532+
// {
533+
// std::ofstream ofs_Xi("Xi.dat", std::ios::binary);
534+
// ofs_Xi.write(reinterpret_cast<char*>(Xi_ace), nbands * nbasis * sizeof(T));
535+
// ofs_Xi.close();
536+
// }
502537
//
503-
// std::cout << "nkb: " << nkb << std::endl;
504-
// std::cout << "nbands: " << p_exx_helper->psi.get_nbands() << std::endl;
505-
// std::cout << "nbasis: " << nbasis << std::endl;
538+
// // std::cout << "nkb: " << nkb << std::endl;
539+
// if (ik == 1 && GlobalV::RANK_IN_POOL == 0)
540+
// {
541+
// std::cout << "nbands: " << p_exx_helper->psi.get_nbands() << std::endl;
542+
// std::cout << "nbasis: " << nbasis << std::endl;
543+
// std::cout << "npwk: " << npwk << std::endl;
544+
// }
506545

507546
// clear mem
508547
setmem_complex_op()(h_psi_ace, 0, nbands * nbasis);
509-
// setmem_complex_op()(Xi_ace, 0, nkb * nbasis);
510548
setmem_complex_op()(psi_h_psi_ace, 0, nbands * nbands);
511549
setmem_complex_op()(L_ace, 0, nbands * nbands);
512550

513551
}
514552

553+
*ik_ptr = ik_store;
554+
515555
// // save h_psi_ace to disk
516556
// std::ofstream ofs_hpsi("hpsi.dat", std::ios::binary);
517557
// ofs_hpsi.write(reinterpret_cast<char*>(h_psi_ace), nkb * nbasis * sizeof(T));

0 commit comments

Comments
 (0)