@@ -118,6 +118,7 @@ void OperatorEXXPW<T, Device>::act(const int nbands,
118118 if (p_exx_helper->construct_ace )
119119 {
120120 construct_ace ();
121+ std::cout << " ACE constructed" << std::endl;
121122 p_exx_helper->construct_ace = false ;
122123 }
123124
@@ -153,7 +154,7 @@ void OperatorEXXPW<T, Device>::act_op(const int nbands,
153154
154155// set_psi(&p_exx_helper->psi);
155156
156- ModuleBase::timer::tick (" OperatorEXXPW" , " act " );
157+ ModuleBase::timer::tick (" OperatorEXXPW" , " act_op " );
157158
158159 setmem_complex_op ()(h_psi_recip, 0 , wfcpw->npwk_max );
159160 setmem_complex_op ()(h_psi_real, 0 , rhopw->nrxx );
@@ -257,7 +258,7 @@ void OperatorEXXPW<T, Device>::act_op(const int nbands,
257258
258259 }
259260
260- ModuleBase::timer::tick (" OperatorEXXPW" , " act " );
261+ ModuleBase::timer::tick (" OperatorEXXPW" , " act_op " );
261262
262263}
263264
@@ -270,6 +271,8 @@ void OperatorEXXPW<T, Device>::act_op_ace(const int nbands,
270271 const int ngk_ik,
271272 const bool is_first_node) const
272273{
274+ ModuleBase::timer::tick (" OperatorEXXPW" , " act_op_ace" );
275+
273276// std::cout << "act_op_ace" << std::endl;
274277 // hpsi += -Xi^\dagger * Xi * psi
275278 auto Xi_ace = Xi_ace_k[this ->ik ];
@@ -325,6 +328,8 @@ void OperatorEXXPW<T, Device>::act_op_ace(const int nbands,
325328// vec_add_vec_complex_op()(this->ctx, nbands * nbasis, tmhpsi, hpsi, -1, tmhpsi, 1);
326329// delmem_complex_op()(hpsi);
327330 delmem_complex_op ()(Xi_psi);
331+ ModuleBase::timer::tick (" OperatorEXXPW" , " act_op_ace" ); // closing tick: label must match the opening tick at the top of act_op_ace
332+
328333}
329334
330335template <typename T, typename Device>
@@ -335,6 +340,9 @@ void OperatorEXXPW<T, Device>::construct_ace() const
335340 int nbasis = p_exx_helper->psi .get_nbasis ();
336341 int nk = p_exx_helper->psi .get_nk ();
337342
343+ int ik_store = this ->ik ;
344+ int *ik_ptr = const_cast <int *>(&this ->ik );
345+
338346 T intermediate_one = 1.0 , intermediate_zero = 0.0 ;
339347
340348 if (h_psi_ace == nullptr )
@@ -368,37 +376,53 @@ void OperatorEXXPW<T, Device>::construct_ace() const
368376 resmem_complex_op ()(psi_h_psi_ace, nbands * nbands);
369377 }
370378
371- // std::ofstream ofs_psi("psi.dat", std::ios::binary);
372- // p_exx_helper->psi.fix_kb(0, 0);
373- // ofs_psi.write(reinterpret_cast<char*>(p_exx_helper->psi.get_pointer()), nkb * nbasis * sizeof(T));
374- // ofs_psi.close();
375-
376- for (int ik = 0 ; ik < p_exx_helper->psi .get_nk (); ik++)
379+ for (int ik = 0 ; ik < nk; ik++)
377380 {
378- auto Xi_ace = Xi_ace_k [ik];
381+ int npwk = wfcpw-> npwk [ik];
379382
383+ T* Xi_ace = Xi_ace_k[ik];
380384 p_exx_helper->psi .fix_kb (ik, 0 );
381385 T* p_psi = p_exx_helper->psi .get_pointer ();
386+
387+ // if (ik == 1 && GlobalV::RANK_IN_POOL == 0)
388+ // {
389+ // std::ofstream ofs_psi("psi.dat", std::ios::binary);
390+ // // p_exx_helper->psi.fix_kb(0, 0);
391+ // ofs_psi.write(reinterpret_cast<char*>(p_psi), nbands * nbasis * sizeof(T));
392+ // ofs_psi.close();
393+ // }
394+
395+ setmem_complex_op ()(h_psi_ace, 0 , nbands * nbasis);
396+
397+ *ik_ptr = ik;
398+
382399 act_op (
383- p_exx_helper-> psi . get_nbands () ,
400+ nbands ,
384401 nbasis,
385402 1 ,
386403 p_psi,
387404 h_psi_ace,
388- ik ,
405+ nbasis ,
389406 false
390407 );
391408
409+ // if (ik == 1 && GlobalV::RANK_IN_POOL == 0)
410+ // {
411+ // std::ofstream ofs_hpsi("hpsi.dat", std::ios::binary);
412+ // ofs_hpsi.write(reinterpret_cast<char*>(h_psi_ace), nbands * nbasis * sizeof(T));
413+ // ofs_hpsi.close();
414+ // }
415+
392416 // psi_h_psi_ace = psi^\dagger * h_psi_ace
393- p_exx_helper->psi .fix_kb (0 , 0 );
417+ // p_exx_helper->psi.fix_kb(0, 0);
394418 gemm_complex_op ()(this ->ctx ,
395419 ' C' ,
396420 ' N' ,
397421 nbands,
398422 nbands,
399- nbasis ,
423+ npwk ,
400424 &intermediate_one,
401- p_exx_helper-> psi . get_pointer () ,
425+ p_psi ,
402426 nbasis,
403427 h_psi_ace,
404428 nbasis,
@@ -407,17 +431,19 @@ void OperatorEXXPW<T, Device>::construct_ace() const
407431 nbands);
408432
409433 // reduction of psi_h_psi_ace, due to distributed memory
410- Parallel_Reduce::reduce_pool (psi_h_psi_ace, nbands * nbands);
411-
412- // // save psi_h_psi_ace to disk
434+ Parallel_Reduce::reduce_pool (psi_h_psi_ace, nbands * nbands);
435+
436+ // if (ik == 1 && GlobalV::RANK_IN_POOL == 0)
437+ // { // save psi_h_psi_ace to disk
413438 // std::ofstream ofs_psi_hpsi("psihpsi.dat", std::ios::binary);
414- // ofs_psi_hpsi.write(reinterpret_cast<char*>(psi_h_psi_ace), nkb * nkb * sizeof(T));
439+ // ofs_psi_hpsi.write(reinterpret_cast<char*>(psi_h_psi_ace), nbands * nbands * sizeof(T));
415440 // ofs_psi_hpsi.close();
441+ // }
416442
417- // L_ace = cholesky(-psi_h_psi_ace)
418- #ifdef _OPENMP
419- #pragma omp parallel for schedule(static)
420- #endif
443+ // L_ace = cholesky(-psi_h_psi_ace)
444+ #ifdef _OPENMP
445+ #pragma omp parallel for schedule(static)
446+ #endif
421447 for (int i = 0 ; i < nbands; i++)
422448 {
423449 for (int j = 0 ; j < nbands; j++)
@@ -438,26 +464,29 @@ void OperatorEXXPW<T, Device>::construct_ace() const
438464 zpotrf_ (&lo, &nbands, L_ace, &nbands, &info);
439465 }
440466
441- // expand for-loop
442- #ifdef _OPENMP
443- #pragma omp parallel for schedule(static)
444- #endif
467+ // expand for-loop
468+ #ifdef _OPENMP
469+ #pragma omp parallel for schedule(static)
470+ #endif
445471 for (int i = 0 ; i < nbands; i++)
446472 {
447473 for (int j = 0 ; j < nbands; j++)
448474 {
449475 if (j < i)
450476 {
451- // L_ace[j * nkb + i] = std::conj(L_ace[i * nkb + j]);
477+ // L_ace[j * nkb + i] = std::conj(L_ace[i * nkb + j]);
452478 L_ace[i * nbands + j] = 0.0 ;
453479 }
454480 }
455481 }
456482
457- // // save L_ace to disk
458- // std::ofstream ofs_L("L.dat", std::ios::binary);
459- // ofs_L.write(reinterpret_cast<char*>(L_ace), nkb * nkb * sizeof(T));
460- // ofs_L.close();
483+ // // save L_ace to disk
484+ // if (ik == 1 && GlobalV::RANK_IN_POOL == 0)
485+ // {
486+ // std::ofstream ofs_L("L.dat", std::ios::binary);
487+ // ofs_L.write(reinterpret_cast<char*>(L_ace), nbands * nbands * sizeof(T));
488+ // ofs_L.close();
489+ // }
461490
462491 // L_ace inv in place
463492 // T == std::complex<float> or std::complex<double>
@@ -474,17 +503,20 @@ void OperatorEXXPW<T, Device>::construct_ace() const
474503 ztrtri_ (&lo, &non_unitary, &nbands, L_ace, &nbands, &info);
475504 }
476505
477- // // save L_ace inv to disk
478- // std::ofstream ofs_L_inv("L_inv.dat", std::ios::binary);
479- // ofs_L_inv.write(reinterpret_cast<char*>(L_ace), nkb * nkb * sizeof(T));
480- // ofs_L_inv.close();
506+ // // save L_ace inv to disk
507+ // if (ik == 1 && GlobalV::RANK_IN_POOL == 0)
508+ // {
509+ // std::ofstream ofs_L_inv("L_inv.dat", std::ios::binary);
510+ // ofs_L_inv.write(reinterpret_cast<char*>(L_ace), nbands * nbands * sizeof(T));
511+ // ofs_L_inv.close();
512+ // }
481513
482514 // Xi_ace = L_ace^-1 * h_psi_ace^dagger
483515 gemm_complex_op ()(this ->ctx ,
484516 ' N' ,
485517 ' C' ,
486518 nbands,
487- nbasis ,
519+ npwk ,
488520 nbands,
489521 &intermediate_one,
490522 L_ace,
@@ -495,23 +527,31 @@ void OperatorEXXPW<T, Device>::construct_ace() const
495527 Xi_ace,
496528 nbands);
497529
498- // // save Xi_ace to disk
499- // std::ofstream ofs_Xi("Xi.dat", std::ios::binary);
500- // ofs_Xi.write(reinterpret_cast<char*>(Xi_ace), nkb * nbasis * sizeof(T));
501- // ofs_Xi.close();
530+ // save Xi_ace to disk
531+ // if (ik == 1 && GlobalV::RANK_IN_POOL == 0)
532+ // {
533+ // std::ofstream ofs_Xi("Xi.dat", std::ios::binary);
534+ // ofs_Xi.write(reinterpret_cast<char*>(Xi_ace), nbands * nbasis * sizeof(T));
535+ // ofs_Xi.close();
536+ // }
502537 //
503- // std::cout << "nkb: " << nkb << std::endl;
504- // std::cout << "nbands: " << p_exx_helper->psi.get_nbands() << std::endl;
505- // std::cout << "nbasis: " << nbasis << std::endl;
538+ // // std::cout << "nkb: " << nkb << std::endl;
539+ // if (ik == 1 && GlobalV::RANK_IN_POOL == 0)
540+ // {
541+ // std::cout << "nbands: " << p_exx_helper->psi.get_nbands() << std::endl;
542+ // std::cout << "nbasis: " << nbasis << std::endl;
543+ // std::cout << "npwk: " << npwk << std::endl;
544+ // }
506545
507546 // clear mem
508547 setmem_complex_op ()(h_psi_ace, 0 , nbands * nbasis);
509- // setmem_complex_op()(Xi_ace, 0, nkb * nbasis);
510548 setmem_complex_op ()(psi_h_psi_ace, 0 , nbands * nbands);
511549 setmem_complex_op ()(L_ace, 0 , nbands * nbands);
512550
513551 }
514552
553+ *ik_ptr = ik_store;
554+
515555// // save h_psi_ace to disk
516556// std::ofstream ofs_hpsi("hpsi.dat", std::ios::binary);
517557// ofs_hpsi.write(reinterpret_cast<char*>(h_psi_ace), nkb * nbasis * sizeof(T));
0 commit comments