@@ -96,7 +96,7 @@ void HSolverPW<T, Device>::call_paw_cell_set_currentk(const int ik)
9696}
9797
9898template <typename T, typename Device>
99- void HSolverPW<T, Device>::paw_func_after_kloop(psi::Psi<T, Device >& psi,
99+ void HSolverPW<T, Device>::paw_func_after_kloop(psi::Psi<T>& psi,
100100 elecstate::ElecState* pes,
101101 const double tpiba,
102102 const int nat)
@@ -744,13 +744,18 @@ void HSolverPW<T, Device>::propagate_psi(psi::Psi<T>& psi, const int from_ik, co
744744 // Get k-point difference
745745 ModuleBase::Vector3<double > dk = kvecs_c[to_ik] - kvecs_c[from_ik];
746746
747- // Allocate temporary arrays using device-aware memory management
747+ // Allocate porter locally
748748 T* porter = nullptr ;
749749 resmem_complex_op ()(this ->ctx , porter, this ->wfc_basis ->nmaxgr , " HSolverPW::porter" );
750750
751751 // Process each band
752- for (int ib = 0 ; ib < nbands; ++ib) {
752+ for (int ib = 0 ; ib < nbands; ib++)
753+ {
754+ // Fix current k-point and band
755+ // psi.fix_k(from_ik);
756+
753757 // FFT to real space
758+ // this->wfc_basis->recip_to_real(this->ctx, psi.get_pointer(ib), porter, from_ik);
754759 this ->wfc_basis ->recip_to_real (this ->ctx , &psi (from_ik, ib, 0 ), porter, from_ik);
755760
756761 // Apply phase factor
@@ -760,11 +765,15 @@ void HSolverPW<T, Device>::propagate_psi(psi::Psi<T>& psi, const int from_ik, co
760765 // psi_real[ir] *= std::exp(std::complex<double>(0.0, phase));
761766 // }
762767
768+ // Fix k-point for target
769+ // psi.fix_k(to_ik);
770+
763771 // FFT back to reciprocal space
764- this ->wfc_basis ->real_to_recip (this ->ctx , porter, &psi (to_ik, ib, 0 ), to_ik, true );
772+ // this->wfc_basis->real_to_recip(this->ctx, porter, psi.get_pointer(ib), to_ik, true);
773+ this ->wfc_basis ->real_to_recip (this ->ctx , porter, &psi (to_ik, ib, 0 ), to_ik);
765774 }
766-
767- // Clean up temporary arrays
775+
776+ // Clean up porter
768777 delmem_complex_op ()(this ->ctx , porter);
769778}
770779
0 commit comments