Skip to content

Commit d23bb1d

Browse files
committed
fix FFT call
1 parent fe8e0c0 commit d23bb1d

File tree

2 files changed

+50
-30
lines changed

2 files changed

+50
-30
lines changed

source/module_hsolver/hsolver_pw.cpp

Lines changed: 49 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -280,17 +280,18 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
280280
std::vector<Real> eigenvalues(this->wfc_basis->nks * psi.get_nbands(), 0.0);
281281
ethr_band.resize(psi.get_nbands(), this->diag_thr);
282282

283-
// Check if using k-point continuity
284-
use_k_continuity = (PARAM.init_wfc == "kcontinuity");
283+
// using k-point continuity
285284
if (use_k_continuity) {
286285
build_k_neighbors();
287286
}
288287

288+
static int count = 0;
289+
289290
/// Loop over k points for solve Hamiltonian to charge density
290291
for (int i = 0; i < this->wfc_basis->nks; ++i)
291292
{
292-
const int ik = use_k_continuity ? k_order[i] : i;
293-
293+
const int ik = use_k_continuity ? k_order[i] : i;
294+
ModuleBase::timer::tick("HsolverPW", "k_point: " + std::to_string(ik));
294295
/// update H(k) for each k point
295296
pHamilt->updateHk(ik);
296297

@@ -302,7 +303,7 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
302303
psi.fix_k(ik);
303304

304305
// If using k-point continuity and not first k-point, propagate from parent
305-
if (use_k_continuity && i > 0) {
306+
if (use_k_continuity && ik > 0 && count == 0) {
306307
propagate_psi(psi, k_parent[ik], ik);
307308
}
308309

@@ -336,8 +337,10 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
336337
<< " ; where current threshold is: " << this->diag_thr << " . " << std::endl;
337338
DiagoIterAssist<T, Device>::avg_iter = 0.0;
338339
}
340+
ModuleBase::timer::tick("HsolverPW", "k_point: " + std::to_string(ik));
339341
/// calculate the contribution of Psi for charge density rho
340342
}
343+
count++;
341344
// END Loop over k points
342345

343346
// copy eigenvalues to ekb in ElecState
@@ -680,39 +683,53 @@ template <typename T, typename Device>
680683
void HSolverPW<T, Device>::build_k_neighbors() {
681684
const int nk = this->wfc_basis->nks;
682685
kvecs_c.resize(nk);
686+
k_order.clear();
683687
k_order.reserve(nk);
684688

685-
// Build k-point list
686-
std::vector<std::pair<ModuleBase::Vector3<double>, int>> klist;
689+
// 存储k点和对应索引的结构体
690+
struct KPoint {
691+
ModuleBase::Vector3<double> kvec;
692+
int index;
693+
double norm;
694+
695+
KPoint(const ModuleBase::Vector3<double>& v, int i) :
696+
kvec(v), index(i), norm(v.norm()) {}
697+
};
698+
699+
// 构建k点列表
700+
std::vector<KPoint> klist;
687701
for (int ik = 0; ik < nk; ++ik) {
688-
kvecs_c[ik] = this->pes->klist->kvec_c[ik];
689-
klist.emplace_back(kvecs_c[ik], ik);
702+
kvecs_c[ik] = this->wfc_basis->kvec_c[ik];
703+
klist.push_back(KPoint(kvecs_c[ik], ik));
690704
}
691705

692-
// Sort k-points by distance from origin
706+
// 按照到原点距离排序k点
693707
std::sort(klist.begin(), klist.end(),
694-
[](const auto& a, const auto& b) {
695-
return a.first.norm() < b.first.norm();
708+
[](const KPoint& a, const KPoint& b) {
709+
return a.norm < b.norm;
696710
});
697711

698-
// Build parent-child relationships
699-
k_order.push_back(klist[0].second);
712+
// 构建父子关系
713+
k_order.push_back(klist[0].index);
700714

701-
for (size_t i = 1; i < klist.size(); ++i) {
702-
int ik = klist[i].second;
715+
// 对每个k点找最近的已处理k点作为父节点
716+
for (int i = 1; i < nk; ++i) {
717+
int current_k = klist[i].index;
703718
double min_dist = 1e10;
704719
int parent = -1;
705720

706-
for (int jk : k_order) {
707-
double dist = (kvecs_c[ik] - kvecs_c[jk]).norm2();
721+
// 在已处理的k点中找最近邻
722+
for (int j = 0; j < k_order.size(); ++j) {
723+
int processed_k = k_order[j];
724+
double dist = (kvecs_c[current_k] - kvecs_c[processed_k]).norm2();
708725
if (dist < min_dist) {
709726
min_dist = dist;
710-
parent = jk;
727+
parent = processed_k;
711728
}
712729
}
713730

714-
k_parent[ik] = parent;
715-
k_order.push_back(ik);
731+
k_parent[current_k] = parent;
732+
k_order.push_back(current_k);
716733
}
717734
}
718735

@@ -730,17 +747,20 @@ void HSolverPW<T, Device>::propagate_psi(psi::Psi<T>& psi, const int from_ik, co
730747
// Process each band
731748
for (int ib = 0; ib < nbands; ++ib) {
732749
// IFFT to real space
733-
this->wfc_basis->recip2real(psi.get_pointer(from_ik, ib), psi_real.data(), from_ik);
750+
// TODO: Check if the call is correct
751+
this->wfc_basis->recip_to_real(this->ctx, &psi(from_ik, ib, 0), psi_real.data(), from_ik);
734752

735753
// Apply phase factor
736-
for (int ir = 0; ir < this->wfc_basis->nrxx; ++ir) {
737-
ModuleBase::Vector3<double> r = this->wfc_basis->get_ir2r(ir);
738-
double phase = this->wfc_basis->tpiba * (dk.x * r.x + dk.y * r.y + dk.z * r.z);
739-
psi_real[ir] *= std::exp(std::complex<double>(0.0, phase));
740-
}
754+
// // TODO: Check how to get the r vector
755+
// ModuleBase::Vector3<double> r = this->wfc_basis->get_ir2r(ir);
756+
// double phase = this->wfc_basis->tpiba * (dk.x * r.x + dk.y * r.y + dk.z * r.z);
757+
// psi_real[ir] *= std::exp(std::complex<double>(0.0, phase));
758+
// }
741759

742760
// FFT back to reciprocal space
743-
this->wfc_basis->real2recip(psi_real.data(), psi.get_pointer(to_ik, ib), to_ik);
761+
// TODO: Check if the call is correct
762+
763+
this->wfc_basis->real_to_recip(this->ctx, psi_real.data(), psi.get_pointer(ib), to_ik);
744764
}
745765
}
746766

@@ -751,4 +771,4 @@ template class HSolverPW<std::complex<float>, base_device::DEVICE_GPU>;
751771
template class HSolverPW<std::complex<double>, base_device::DEVICE_GPU>;
752772
#endif
753773

754-
} // namespace hsolver
774+
} // namespace hsolver

source/module_hsolver/hsolver_pw.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ class HSolverPW
101101
#endif
102102

103103
// K-point continuity related members
104-
bool use_k_continuity = false;
104+
bool use_k_continuity = true;
105105
std::vector<int> k_order;
106106
std::unordered_map<int, int> k_parent;
107107
std::vector<ModuleBase::Vector3<double>> kvecs_c;

0 commit comments

Comments
 (0)