Skip to content

Commit 2156ecc

Browse files
committed
fix device for cpu & gpu
1 parent e330b57 commit 2156ecc

File tree

2 files changed

+20
-6
lines changed

2 files changed

+20
-6
lines changed

source/module_hsolver/hsolver_pw.cpp

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ void HSolverPW<T, Device>::call_paw_cell_set_currentk(const int ik)
9696
}
9797

9898
template <typename T, typename Device>
99-
void HSolverPW<T, Device>::paw_func_after_kloop(psi::Psi<T, Device>& psi,
99+
void HSolverPW<T, Device>::paw_func_after_kloop(psi::Psi<T>& psi,
100100
elecstate::ElecState* pes,
101101
const double tpiba,
102102
const int nat)
@@ -744,13 +744,18 @@ void HSolverPW<T, Device>::propagate_psi(psi::Psi<T>& psi, const int from_ik, co
744744
// Get k-point difference
745745
ModuleBase::Vector3<double> dk = kvecs_c[to_ik] - kvecs_c[from_ik];
746746

747-
// Allocate temporary arrays using device-aware memory management
747+
// Allocate porter locally
748748
T* porter = nullptr;
749749
resmem_complex_op()(this->ctx, porter, this->wfc_basis->nmaxgr, "HSolverPW::porter");
750750

751751
// Process each band
752-
for (int ib = 0; ib < nbands; ++ib) {
752+
for (int ib = 0; ib < nbands; ib++)
753+
{
754+
// Fix current k-point and band
755+
// psi.fix_k(from_ik);
756+
753757
// FFT to real space
758+
// this->wfc_basis->recip_to_real(this->ctx, psi.get_pointer(ib), porter, from_ik);
754759
this->wfc_basis->recip_to_real(this->ctx, &psi(from_ik, ib, 0), porter, from_ik);
755760

756761
// Apply phase factor
@@ -760,11 +765,15 @@ void HSolverPW<T, Device>::propagate_psi(psi::Psi<T>& psi, const int from_ik, co
760765
// psi_real[ir] *= std::exp(std::complex<double>(0.0, phase));
761766
// }
762767

768+
// Fix k-point for target
769+
// psi.fix_k(to_ik);
770+
763771
// FFT back to reciprocal space
764-
this->wfc_basis->real_to_recip(this->ctx, porter, &psi(to_ik, ib, 0), to_ik, true);
772+
// this->wfc_basis->real_to_recip(this->ctx, porter, psi.get_pointer(ib), to_ik, true);
773+
this->wfc_basis->real_to_recip(this->ctx, porter, &psi(to_ik, ib, 0), to_ik);
765774
}
766-
767-
// Clean up temporary arrays
775+
776+
// Clean up porter
768777
delmem_complex_op()(this->ctx, porter);
769778
}
770779

source/module_hsolver/hsolver_pw.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include "module_basis/module_pw/pw_basis_k.h"
88
#include "module_psi/wavefunc.h"
99
#include <unordered_map>
10+
#include "module_base/memory.h"
1011

1112
namespace hsolver
1213
{
@@ -19,6 +20,9 @@ class HSolverPW
1920
// return T if T is real type(float, double),
2021
// otherwise return the real type of T(complex<float>, complex<double>)
2122
using Real = typename GetTypeReal<T>::type;
23+
using resmem_complex_op = base_device::memory::resize_memory_op<T, Device>;
24+
using delmem_complex_op = base_device::memory::delete_memory_op<T, Device>;
25+
using setmem_complex_op = base_device::memory::set_memory_op<T, Device>;
2226

2327
public:
2428
HSolverPW(ModulePW::PW_Basis_K* wfc_basis_in,
@@ -52,6 +56,7 @@ class HSolverPW
5256
const double tpiba,
5357
const int nat);
5458

59+
5560
protected:
5661
// diago caller
5762
void hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,

0 commit comments

Comments
 (0)