@@ -689,7 +689,7 @@ void HSolverPW<T, Device>::build_k_neighbors() {
689689 k_order.clear ();
690690 k_order.reserve (nk);
691691
692- // 存储k点和对应索引的结构体
692+ // Store k-points and corresponding indices
693693 struct KPoint {
694694 ModuleBase::Vector3<double > kvec;
695695 int index;
@@ -699,29 +699,29 @@ void HSolverPW<T, Device>::build_k_neighbors() {
699699 kvec (v), index(i), norm(v.norm()) {}
700700 };
701701
702- // 构建k点列表
702+ // Build k-point list
703703 std::vector<KPoint> klist;
704704 for (int ik = 0 ; ik < nk; ++ik) {
705705 kvecs_c[ik] = this ->wfc_basis ->kvec_c [ik];
706706 klist.push_back (KPoint (kvecs_c[ik], ik));
707707 }
708708
709- // 按照到原点距离排序k点
709+ // Sort k-points by distance from origin
710710 std::sort (klist.begin (), klist.end (),
711711 [](const KPoint& a, const KPoint& b) {
712712 return a.norm < b.norm ;
713713 });
714714
715- // 构建父子关系
715+ // Build parent-child relationships
716716 k_order.push_back (klist[0 ].index );
717717
718- // 对每个k点找最近的已处理k点作为父节点
718+ // Find nearest processed k-point as parent for each k-point
719719 for (int i = 1 ; i < nk; ++i) {
720720 int current_k = klist[i].index ;
721721 double min_dist = 1e10 ;
722722 int parent = -1 ;
723723
724- // 在已处理的k点中找最近邻
724+ // find the nearest k-point as parent
725725 for (int j = 0 ; j < k_order.size (); ++j) {
726726 int processed_k = k_order[j];
727727 double dist = (kvecs_c[current_k] - kvecs_c[processed_k]).norm2 ();
@@ -744,14 +744,14 @@ void HSolverPW<T, Device>::propagate_psi(psi::Psi<T>& psi, const int from_ik, co
744744 // Get k-point difference
745745 ModuleBase::Vector3<double > dk = kvecs_c[to_ik] - kvecs_c[from_ik];
746746
747- // Allocate temporary arrays
748- std::vector<T> psi_real (this ->wfc_basis ->nrxx );
747+ // Allocate temporary arrays using device-aware memory management
748+ T* porter = nullptr ;
749+ resmem_complex_op ()(this ->ctx , porter, this ->wfc_basis ->nmaxgr , " HSolverPW::porter" );
749750
750751 // Process each band
751752 for (int ib = 0 ; ib < nbands; ++ib) {
752- // IFFT to real space
753- // TODO: Check if the call is correct
754- this ->wfc_basis ->recip_to_real (this ->ctx , &psi (from_ik, ib, 0 ), psi_real.data (), from_ik);
753+ // FFT to real space
754+ this ->wfc_basis ->recip_to_real (this ->ctx , &psi (from_ik, ib, 0 ), porter, from_ik);
755755
756756 // Apply phase factor
757757 // // TODO: Check how to get the r vector
@@ -761,10 +761,11 @@ void HSolverPW<T, Device>::propagate_psi(psi::Psi<T>& psi, const int from_ik, co
761761 // }
762762
763763 // FFT back to reciprocal space
764- // TODO: Check if the call is correct
765-
766- this ->wfc_basis ->real_to_recip (this ->ctx , psi_real.data (), psi.get_pointer (ib), to_ik);
764+ this ->wfc_basis ->real_to_recip (this ->ctx , porter, &psi (to_ik, ib, 0 ), to_ik, true );
767765 }
766+
767+ // Clean up temporary arrays
768+ delmem_complex_op ()(this ->ctx , porter);
768769}
769770
770771template class HSolverPW <std::complex <float >, base_device::DEVICE_CPU>;
0 commit comments