@@ -280,17 +280,18 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
280280 std::vector<Real> eigenvalues (this ->wfc_basis ->nks * psi.get_nbands (), 0.0 );
281281 ethr_band.resize (psi.get_nbands (), this ->diag_thr );
282282
283- // Check if using k-point continuity
284- use_k_continuity = (PARAM.init_wfc == " kcontinuity" );
283+ // using k-point continuity
285284 if (use_k_continuity) {
286285 build_k_neighbors ();
287286 }
288287
288+ static int count = 0 ;
289+
289290 // / Loop over k points for solve Hamiltonian to charge density
290291 for (int i = 0 ; i < this ->wfc_basis ->nks ; ++i)
291292 {
292- const int ik = use_k_continuity ? k_order[i] : i;
293-
293+ const int ik = use_k_continuity ? k_order[i] : i;
294+ ModuleBase::timer::tick ( " HsolverPW " , " k_point: " + std::to_string (ik));
294295 // / update H(k) for each k point
295296 pHamilt->updateHk (ik);
296297
@@ -302,7 +303,7 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
302303 psi.fix_k (ik);
303304
304305 // If using k-point continuity and not first k-point, propagate from parent
305- if (use_k_continuity && i > 0 ) {
306+ if (use_k_continuity && ik > 0 && count == 0 ) {
306307 propagate_psi (psi, k_parent[ik], ik);
307308 }
308309
@@ -336,8 +337,10 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
336337 << " ; where current threshold is: " << this ->diag_thr << " . " << std::endl;
337338 DiagoIterAssist<T, Device>::avg_iter = 0.0 ;
338339 }
340+ ModuleBase::timer::tick (" HsolverPW" , " k_point: " + std::to_string (ik));
339341 // / calculate the contribution of Psi for charge density rho
340342 }
343+ count++;
341344 // END Loop over k points
342345
343346 // copy eigenvalues to ekb in ElecState
@@ -680,39 +683,53 @@ template <typename T, typename Device>
680683void HSolverPW<T, Device>::build_k_neighbors() {
681684 const int nk = this ->wfc_basis ->nks ;
682685 kvecs_c.resize (nk);
686+ k_order.clear ();
683687 k_order.reserve (nk);
684688
685- // Build k-point list
686- std::vector<std::pair<ModuleBase::Vector3<double >, int >> klist;
689+ // 存储k点和对应索引的结构体
690+ struct KPoint {
691+ ModuleBase::Vector3<double > kvec;
692+ int index;
693+ double norm;
694+
695+ KPoint (const ModuleBase::Vector3<double >& v, int i) :
696+ kvec (v), index(i), norm(v.norm()) {}
697+ };
698+
699+ // 构建k点列表
700+ std::vector<KPoint> klist;
687701 for (int ik = 0 ; ik < nk; ++ik) {
688- kvecs_c[ik] = this ->pes -> klist ->kvec_c [ik];
689- klist.emplace_back ( kvecs_c[ik], ik);
702+ kvecs_c[ik] = this ->wfc_basis ->kvec_c [ik];
703+ klist.push_back ( KPoint ( kvecs_c[ik], ik) );
690704 }
691705
692- // Sort k-points by distance from origin
706+ // 按照到原点距离排序k点
693707 std::sort (klist.begin (), klist.end (),
694- [](const auto & a, const auto & b) {
695- return a.first . norm () < b.first . norm () ;
708+ [](const KPoint & a, const KPoint & b) {
709+ return a.norm < b.norm ;
696710 });
697711
698- // Build parent-child relationships
699- k_order.push_back (klist[0 ].second );
712+ // 构建父子关系
713+ k_order.push_back (klist[0 ].index );
700714
701- for (size_t i = 1 ; i < klist.size (); ++i) {
702- int ik = klist[i].second ;
715+ // 对每个k点找最近的已处理k点作为父节点
716+ for (int i = 1 ; i < nk; ++i) {
717+ int current_k = klist[i].index ;
703718 double min_dist = 1e10 ;
704719 int parent = -1 ;
705720
706- for (int jk : k_order) {
707- double dist = (kvecs_c[ik] - kvecs_c[jk]).norm2 ();
721+ // 在已处理的k点中找最近邻
722+ for (int j = 0 ; j < k_order.size (); ++j) {
723+ int processed_k = k_order[j];
724+ double dist = (kvecs_c[current_k] - kvecs_c[processed_k]).norm2 ();
708725 if (dist < min_dist) {
709726 min_dist = dist;
710- parent = jk ;
727+ parent = processed_k ;
711728 }
712729 }
713730
714- k_parent[ik ] = parent;
715- k_order.push_back (ik );
731+ k_parent[current_k ] = parent;
732+ k_order.push_back (current_k );
716733 }
717734}
718735
@@ -730,17 +747,20 @@ void HSolverPW<T, Device>::propagate_psi(psi::Psi<T>& psi, const int from_ik, co
730747 // Process each band
731748 for (int ib = 0 ; ib < nbands; ++ib) {
732749 // IFFT to real space
733- this ->wfc_basis ->recip2real (psi.get_pointer (from_ik, ib), psi_real.data (), from_ik);
750+ // TODO: Check if the call is correct
751+ this ->wfc_basis ->recip_to_real (this ->ctx , &psi (from_ik, ib, 0 ), psi_real.data (), from_ik);
734752
735753 // Apply phase factor
736- for ( int ir = 0 ; ir < this -> wfc_basis -> nrxx ; ++ir) {
737- ModuleBase::Vector3<double > r = this ->wfc_basis ->get_ir2r (ir);
738- double phase = this ->wfc_basis ->tpiba * (dk.x * r.x + dk.y * r.y + dk.z * r.z );
739- psi_real[ir] *= std::exp (std::complex <double >(0.0 , phase));
740- }
754+ // // TODO: Check how to get the r vector
755+ // ModuleBase::Vector3<double> r = this->wfc_basis->get_ir2r(ir);
756+ // double phase = this->wfc_basis->tpiba * (dk.x * r.x + dk.y * r.y + dk.z * r.z);
757+ // psi_real[ir] *= std::exp(std::complex<double>(0.0, phase));
758+ // }
741759
742760 // FFT back to reciprocal space
743- this ->wfc_basis ->real2recip (psi_real.data (), psi.get_pointer (to_ik, ib), to_ik);
761+ // TODO: Check if the call is correct
762+
763+ this ->wfc_basis ->real_to_recip (this ->ctx , psi_real.data (), psi.get_pointer (ib), to_ik);
744764 }
745765}
746766
@@ -751,4 +771,4 @@ template class HSolverPW<std::complex<float>, base_device::DEVICE_GPU>;
751771template class HSolverPW <std::complex <double >, base_device::DEVICE_GPU>;
752772#endif
753773
754- } // namespace hsolver
774+ } // namespace hsolver
0 commit comments