Skip to content

Commit e330b57

Browse files
committed
fix device
1 parent 98fec60 commit e330b57

File tree

1 file changed

+15
-14
lines changed

1 file changed

+15
-14
lines changed

source/module_hsolver/hsolver_pw.cpp

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -689,7 +689,7 @@ void HSolverPW<T, Device>::build_k_neighbors() {
689689
k_order.clear();
690690
k_order.reserve(nk);
691691

692-
// 存储k点和对应索引的结构体
692+
// Store k-points and corresponding indices
693693
struct KPoint {
694694
ModuleBase::Vector3<double> kvec;
695695
int index;
@@ -699,29 +699,29 @@ void HSolverPW<T, Device>::build_k_neighbors() {
699699
kvec(v), index(i), norm(v.norm()) {}
700700
};
701701

702-
// 构建k点列表
702+
// Build k-point list
703703
std::vector<KPoint> klist;
704704
for (int ik = 0; ik < nk; ++ik) {
705705
kvecs_c[ik] = this->wfc_basis->kvec_c[ik];
706706
klist.push_back(KPoint(kvecs_c[ik], ik));
707707
}
708708

709-
// 按照到原点距离排序k点
709+
// Sort k-points by distance from origin
710710
std::sort(klist.begin(), klist.end(),
711711
[](const KPoint& a, const KPoint& b) {
712712
return a.norm < b.norm;
713713
});
714714

715-
// 构建父子关系
715+
// Build parent-child relationships
716716
k_order.push_back(klist[0].index);
717717

718-
// 对每个k点找最近的已处理k点作为父节点
718+
// Find nearest processed k-point as parent for each k-point
719719
for (int i = 1; i < nk; ++i) {
720720
int current_k = klist[i].index;
721721
double min_dist = 1e10;
722722
int parent = -1;
723723

724-
// 在已处理的k点中找最近邻
724+
// find the nearest k-point as parent
725725
for (int j = 0; j < k_order.size(); ++j) {
726726
int processed_k = k_order[j];
727727
double dist = (kvecs_c[current_k] - kvecs_c[processed_k]).norm2();
@@ -744,14 +744,14 @@ void HSolverPW<T, Device>::propagate_psi(psi::Psi<T>& psi, const int from_ik, co
744744
// Get k-point difference
745745
ModuleBase::Vector3<double> dk = kvecs_c[to_ik] - kvecs_c[from_ik];
746746

747-
// Allocate temporary arrays
748-
std::vector<T> psi_real(this->wfc_basis->nrxx);
747+
// Allocate temporary arrays using device-aware memory management
748+
T* porter = nullptr;
749+
resmem_complex_op()(this->ctx, porter, this->wfc_basis->nmaxgr, "HSolverPW::porter");
749750

750751
// Process each band
751752
for (int ib = 0; ib < nbands; ++ib) {
752-
// IFFT to real space
753-
// TODO: Check if the call is correct
754-
this->wfc_basis->recip_to_real(this->ctx, &psi(from_ik, ib, 0), psi_real.data(), from_ik);
753+
// FFT to real space
754+
this->wfc_basis->recip_to_real(this->ctx, &psi(from_ik, ib, 0), porter, from_ik);
755755

756756
// Apply phase factor
757757
// // TODO: Check how to get the r vector
@@ -761,10 +761,11 @@ void HSolverPW<T, Device>::propagate_psi(psi::Psi<T>& psi, const int from_ik, co
761761
// }
762762

763763
// FFT back to reciprocal space
764-
// TODO: Check if the call is correct
765-
766-
this->wfc_basis->real_to_recip(this->ctx, psi_real.data(), psi.get_pointer(ib), to_ik);
764+
this->wfc_basis->real_to_recip(this->ctx, porter, &psi(to_ik, ib, 0), to_ik, true);
767765
}
766+
767+
// Clean up temporary arrays
768+
delmem_complex_op()(this->ctx, porter);
768769
}
769770

770771
template class HSolverPW<std::complex<float>, base_device::DEVICE_CPU>;

0 commit comments

Comments
 (0)