@@ -280,9 +280,17 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
280280 std::vector<Real> eigenvalues (this ->wfc_basis ->nks * psi.get_nbands (), 0.0 );
281281 ethr_band.resize (psi.get_nbands (), this ->diag_thr );
282282
283+ // Check if using k-point continuity
284+ use_k_continuity = (PARAM.init_wfc == " kcontinuity" );
285+ if (use_k_continuity) {
286+ build_k_neighbors ();
287+ }
288+
283289 // / Loop over k points for solve Hamiltonian to charge density
284- for (int ik = 0 ; ik < this ->wfc_basis ->nks ; ++ik )
290+ for (int i = 0 ; i < this ->wfc_basis ->nks ; ++i )
285291 {
292+ const int ik = use_k_continuity ? k_order[i] : i;
293+
286294 // / update H(k) for each k point
287295 pHamilt->updateHk (ik);
288296
@@ -293,6 +301,11 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
293301 // / update psi pointer for each k point
294302 psi.fix_k (ik);
295303
304+ // If using k-point continuity and not first k-point, propagate from parent
305+ if (use_k_continuity && i > 0 ) {
306+ propagate_psi (psi, k_parent[ik], ik);
307+ }
308+
296309 // template add precondition calculating here
297310 update_precondition (precondition, ik, this ->wfc_basis ->npwk [ik], Real (pes->pot ->get_vl_of_0 ()));
298311
@@ -663,6 +676,74 @@ void HSolverPW<T, Device>::output_iterInfo()
663676 }
664677}
665678
679+ template <typename T, typename Device>
680+ void HSolverPW<T, Device>::build_k_neighbors() {
681+ const int nk = this ->wfc_basis ->nks ;
682+ kvecs_c.resize (nk);
683+ k_order.reserve (nk);
684+
685+ // Build k-point list
686+ std::vector<std::pair<ModuleBase::Vector3<double >, int >> klist;
687+ for (int ik = 0 ; ik < nk; ++ik) {
688+ kvecs_c[ik] = this ->pes ->klist ->kvec_c [ik];
689+ klist.emplace_back (kvecs_c[ik], ik);
690+ }
691+
692+ // Sort k-points by distance from origin
693+ std::sort (klist.begin (), klist.end (),
694+ [](const auto & a, const auto & b) {
695+ return a.first .norm () < b.first .norm ();
696+ });
697+
698+ // Build parent-child relationships
699+ k_order.push_back (klist[0 ].second );
700+
701+ for (size_t i = 1 ; i < klist.size (); ++i) {
702+ int ik = klist[i].second ;
703+ double min_dist = 1e10 ;
704+ int parent = -1 ;
705+
706+ for (int jk : k_order) {
707+ double dist = (kvecs_c[ik] - kvecs_c[jk]).norm2 ();
708+ if (dist < min_dist) {
709+ min_dist = dist;
710+ parent = jk;
711+ }
712+ }
713+
714+ k_parent[ik] = parent;
715+ k_order.push_back (ik);
716+ }
717+ }
718+
719+ template <typename T, typename Device>
720+ void HSolverPW<T, Device>::propagate_psi(psi::Psi<T>& psi, const int from_ik, const int to_ik) {
721+ const int nbands = psi.get_nbands ();
722+ const int npwk = this ->wfc_basis ->npwk [to_ik];
723+
724+ // Get k-point difference
725+ ModuleBase::Vector3<double > dk = kvecs_c[to_ik] - kvecs_c[from_ik];
726+
727+ // Allocate temporary arrays
728+ std::vector<T> psi_real (this ->wfc_basis ->nrxx );
729+
730+ // Process each band
731+ for (int ib = 0 ; ib < nbands; ++ib) {
732+ // IFFT to real space
733+ this ->wfc_basis ->recip2real (psi.get_pointer (from_ik, ib), psi_real.data (), from_ik);
734+
735+ // Apply phase factor
736+ for (int ir = 0 ; ir < this ->wfc_basis ->nrxx ; ++ir) {
737+ ModuleBase::Vector3<double > r = this ->wfc_basis ->get_ir2r (ir);
738+ double phase = this ->wfc_basis ->tpiba * (dk.x * r.x + dk.y * r.y + dk.z * r.z );
739+ psi_real[ir] *= std::exp (std::complex <double >(0.0 , phase));
740+ }
741+
742+ // FFT back to reciprocal space
743+ this ->wfc_basis ->real2recip (psi_real.data (), psi.get_pointer (to_ik, ib), to_ik);
744+ }
745+ }
746+
666747template class HSolverPW <std::complex <float >, base_device::DEVICE_CPU>;
667748template class HSolverPW <std::complex <double >, base_device::DEVICE_CPU>;
668749#if ((defined __CUDA) || (defined __ROCM))
0 commit comments