Skip to content

Commit 9a31ef1

Browse files
committed
add k continuity in hsolver
1 parent 74284f9 commit 9a31ef1

File tree

2 files changed

+91
-1
lines changed

2 files changed

+91
-1
lines changed

source/module_hsolver/hsolver_pw.cpp

Lines changed: 82 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,9 +280,17 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
280280
std::vector<Real> eigenvalues(this->wfc_basis->nks * psi.get_nbands(), 0.0);
281281
ethr_band.resize(psi.get_nbands(), this->diag_thr);
282282

283+
// Check if using k-point continuity
284+
use_k_continuity = (PARAM.init_wfc == "kcontinuity");
285+
if (use_k_continuity) {
286+
build_k_neighbors();
287+
}
288+
283289
/// Loop over k points for solve Hamiltonian to charge density
284-
for (int ik = 0; ik < this->wfc_basis->nks; ++ik)
290+
for (int i = 0; i < this->wfc_basis->nks; ++i)
285291
{
292+
const int ik = use_k_continuity ? k_order[i] : i;
293+
286294
/// update H(k) for each k point
287295
pHamilt->updateHk(ik);
288296

@@ -293,6 +301,11 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
293301
/// update psi pointer for each k point
294302
psi.fix_k(ik);
295303

304+
// If using k-point continuity and not first k-point, propagate from parent
305+
if (use_k_continuity && i > 0) {
306+
propagate_psi(psi, k_parent[ik], ik);
307+
}
308+
296309
// template add precondition calculating here
297310
update_precondition(precondition, ik, this->wfc_basis->npwk[ik], Real(pes->pot->get_vl_of_0()));
298311

@@ -663,6 +676,74 @@ void HSolverPW<T, Device>::output_iterInfo()
663676
}
664677
}
665678

679+
template <typename T, typename Device>
680+
void HSolverPW<T, Device>::build_k_neighbors() {
681+
const int nk = this->wfc_basis->nks;
682+
kvecs_c.resize(nk);
683+
k_order.reserve(nk);
684+
685+
// Build k-point list
686+
std::vector<std::pair<ModuleBase::Vector3<double>, int>> klist;
687+
for (int ik = 0; ik < nk; ++ik) {
688+
kvecs_c[ik] = this->pes->klist->kvec_c[ik];
689+
klist.emplace_back(kvecs_c[ik], ik);
690+
}
691+
692+
// Sort k-points by distance from origin
693+
std::sort(klist.begin(), klist.end(),
694+
[](const auto& a, const auto& b) {
695+
return a.first.norm() < b.first.norm();
696+
});
697+
698+
// Build parent-child relationships
699+
k_order.push_back(klist[0].second);
700+
701+
for (size_t i = 1; i < klist.size(); ++i) {
702+
int ik = klist[i].second;
703+
double min_dist = 1e10;
704+
int parent = -1;
705+
706+
for (int jk : k_order) {
707+
double dist = (kvecs_c[ik] - kvecs_c[jk]).norm2();
708+
if (dist < min_dist) {
709+
min_dist = dist;
710+
parent = jk;
711+
}
712+
}
713+
714+
k_parent[ik] = parent;
715+
k_order.push_back(ik);
716+
}
717+
}
718+
719+
template <typename T, typename Device>
720+
void HSolverPW<T, Device>::propagate_psi(psi::Psi<T>& psi, const int from_ik, const int to_ik) {
721+
const int nbands = psi.get_nbands();
722+
const int npwk = this->wfc_basis->npwk[to_ik];
723+
724+
// Get k-point difference
725+
ModuleBase::Vector3<double> dk = kvecs_c[to_ik] - kvecs_c[from_ik];
726+
727+
// Allocate temporary arrays
728+
std::vector<T> psi_real(this->wfc_basis->nrxx);
729+
730+
// Process each band
731+
for (int ib = 0; ib < nbands; ++ib) {
732+
// IFFT to real space
733+
this->wfc_basis->recip2real(psi.get_pointer(from_ik, ib), psi_real.data(), from_ik);
734+
735+
// Apply phase factor
736+
for (int ir = 0; ir < this->wfc_basis->nrxx; ++ir) {
737+
ModuleBase::Vector3<double> r = this->wfc_basis->get_ir2r(ir);
738+
double phase = this->wfc_basis->tpiba * (dk.x * r.x + dk.y * r.y + dk.z * r.z);
739+
psi_real[ir] *= std::exp(std::complex<double>(0.0, phase));
740+
}
741+
742+
// FFT back to reciprocal space
743+
this->wfc_basis->real2recip(psi_real.data(), psi.get_pointer(to_ik, ib), to_ik);
744+
}
745+
}
746+
666747
template class HSolverPW<std::complex<float>, base_device::DEVICE_CPU>;
667748
template class HSolverPW<std::complex<double>, base_device::DEVICE_CPU>;
668749
#if ((defined __CUDA) || (defined __ROCM))

source/module_hsolver/hsolver_pw.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,15 @@ class HSolverPW
9898

9999
void paw_func_after_kloop(psi::Psi<T, Device>& psi, elecstate::ElecState* pes,const double tpiba,const int nat);
100100
#endif
101+
102+
// K-point continuity related members
103+
bool use_k_continuity = false;
104+
std::vector<int> k_order;
105+
std::unordered_map<int, int> k_parent;
106+
std::vector<ModuleBase::Vector3<double>> kvecs_c;
107+
108+
void build_k_neighbors();
109+
void propagate_psi(psi::Psi<T>& psi, const int from_ik, const int to_ik);
101110
};
102111

103112
} // namespace hsolver

0 commit comments

Comments
 (0)