Skip to content

Commit 3301b52

Browse files
committed
moidfy back the func
1 parent ee48127 commit 3301b52

File tree

10 files changed

+36
-29
lines changed

10 files changed

+36
-29
lines changed

source/module_basis/module_pw/test/test-other.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,26 +76,26 @@ TEST_F(PWTEST,test_other)
7676
}
7777
#endif
7878

79-
pwktest.recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(rhog1, rhor1, ik);
79+
pwktest.recip_to_real(ctx, rhog1, rhor1, ik);
8080
pwktest.recip2real(rhog2, rhor2, ik);
8181
for(int ir = 0 ; ir < nrxx; ++ir)
8282
{
8383
EXPECT_NEAR(std::abs(rhor1[ir]),std::abs(rhor2[ir]),1e-8);
8484
}
85-
pwktest.real_to_recip<std::complex<double>,base_device::DEVICE_CPU>(rhor1, rhog1, ik);
85+
pwktest.real_to_recip(ctx, rhor1, rhog1, ik);
8686
pwktest.real2recip(rhor2, rhog2, ik);
8787
for(int ig = 0 ; ig < npwk; ++ig)
8888
{
8989
EXPECT_NEAR(std::abs(rhog1[ig]),std::abs(rhog2[ig]),1e-8);
9090
}
9191
#ifdef __ENABLE_FLOAT_FFTW
92-
pwktest.recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(rhofg1, rhofr1, ik);
92+
pwktest.recip_to_real(ctx, rhofg1, rhofr1, ik);
9393
pwktest.recip2real(rhofg2, rhofr2, ik);
9494
for(int ir = 0 ; ir < nrxx; ++ir)
9595
{
9696
EXPECT_NEAR(std::abs(rhofr1[ir]),std::abs(rhofr2[ir]),1e-6);
9797
}
98-
pwktest.real_to_recip<std::complex<float>,base_device::DEVICE_CPU>(rhofr1, rhofg1, ik);
98+
pwktest.real_to_recip(ctx, rhofr1, rhofg1, ik);
9999
pwktest.real2recip(rhofr2, rhofg2, ik);
100100
for(int ig = 0 ; ig < npwk; ++ig)
101101
{

source/module_elecstate/elecstate_pw.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -202,9 +202,9 @@ void ElecStatePW<T, Device>::rhoBandK(const psi::Psi<T, Device>& psi)
202202
/// be care of when smearing_sigma is large, wg would less than 0
203203
///
204204

205-
this->basis->recip_to_real<T,Device>( &psi(ibnd,0), this->wfcr, ik);
205+
this->basis->recip_to_real(this->ctx, &psi(ibnd,0), this->wfcr, ik);
206206

207-
this->basis->recip_to_real<T,Device>( &psi(ibnd,npwx), this->wfcr_another_spin, ik);
207+
this->basis->recip_to_real(this->ctx, &psi(ibnd,npwx), this->wfcr_another_spin, ik);
208208

209209
const auto w1 = static_cast<Real>(this->wg(ik, ibnd) / ucell->omega);
210210

@@ -230,7 +230,7 @@ void ElecStatePW<T, Device>::rhoBandK(const psi::Psi<T, Device>& psi)
230230
/// only occupied band should be calculated.
231231
///
232232

233-
this->basis->recip_to_real<T,Device>(&psi(ibnd,0), this->wfcr, ik);
233+
this->basis->recip_to_real(this->ctx, &psi(ibnd,0), this->wfcr, ik);
234234

235235
const auto w1 = static_cast<Real>(this->wg(ik, ibnd) / ucell->omega);
236236

@@ -258,7 +258,7 @@ void ElecStatePW<T, Device>::rhoBandK(const psi::Psi<T, Device>& psi)
258258
&psi(ibnd, 0),
259259
this->wfcr);
260260

261-
this->basis->recip_to_real<T,Device>( this->wfcr, this->wfcr, ik);
261+
this->basis->recip_to_real(this->ctx, this->wfcr, this->wfcr, ik);
262262

263263
elecstate_pw_op()(this->ctx, current_spin, this->charge->nrxx, w1, this->kin_r, this->wfcr);
264264
}

source/module_elecstate/elecstate_pw_cal_tau.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ void ElecStatePW<T, Device>::cal_tau(const psi::Psi<T, Device>& psi)
2323
int nbands = psi.get_nbands();
2424
for (int ibnd = 0; ibnd < nbands; ibnd++)
2525
{
26-
this->basis->recip_to_real<T,Device>(&psi(ibnd,0), this->wfcr, ik);
26+
this->basis->recip_to_real(this->ctx, &psi(ibnd,0), this->wfcr, ik);
2727

2828
const auto w1 = static_cast<Real>(this->wg(ik, ibnd) / ucell->omega);
2929

@@ -43,7 +43,7 @@ void ElecStatePW<T, Device>::cal_tau(const psi::Psi<T, Device>& psi)
4343
&psi(ibnd, 0),
4444
this->wfcr);
4545

46-
this->basis->recip_to_real<T,Device>(this->wfcr, this->wfcr, ik);
46+
this->basis->recip_to_real(this->ctx, this->wfcr, this->wfcr, ik);
4747

4848
elecstate_pw_op()(this->ctx, current_spin, this->charge->nrxx, w1, this->kin_r, this->wfcr);
4949
}

source/module_hamilt_general/module_xc/test/xc3_mock.h

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,8 @@ namespace ModulePW
7878

7979

8080
template <typename FPTYPE, typename Device>
81-
void PW_Basis_K::real_to_recip(const std::complex<FPTYPE>* in,
81+
void PW_Basis_K::real_to_recip(const Device* ctx,
82+
const std::complex<FPTYPE>* in,
8283
std::complex<FPTYPE>* out,
8384
const int ik,
8485
const bool add,
@@ -90,7 +91,8 @@ namespace ModulePW
9091
}
9192
}
9293
template <typename FPTYPE, typename Device>
93-
void PW_Basis_K::recip_to_real(const std::complex<FPTYPE>* in,
94+
void PW_Basis_K::recip_to_real(const Device* ctx,
95+
const std::complex<FPTYPE>* in,
9496
std::complex<FPTYPE>* out,
9597
const int ik,
9698
const bool add,
@@ -102,24 +104,28 @@ namespace ModulePW
102104
}
103105
}
104106

105-
template void PW_Basis_K::real_to_recip<double, base_device::DEVICE_CPU>(const std::complex<double>* in,
107+
template void PW_Basis_K::real_to_recip<double, base_device::DEVICE_CPU>(const base_device::DEVICE_CPU* ctx,
108+
const std::complex<double>* in,
106109
std::complex<double>* out,
107110
const int ik,
108111
const bool add,
109112
const double factor) const;
110-
template void PW_Basis_K::recip_to_real<double, base_device::DEVICE_CPU>(const std::complex<double>* in,
113+
template void PW_Basis_K::recip_to_real<double, base_device::DEVICE_CPU>(const base_device::DEVICE_CPU* ctx,
114+
const std::complex<double>* in,
111115
std::complex<double>* out,
112116
const int ik,
113117
const bool add,
114118
const double factor) const;
115119
#if __CUDA || __ROCM
116-
template void PW_Basis_K::real_to_recip<double, base_device::DEVICE_GPU>(const std::complex<double>* in,
120+
template void PW_Basis_K::real_to_recip<double, base_device::DEVICE_GPU>(const base_device::DEVICE_GPU* ctx,
121+
const std::complex<double>* in,
117122
std::complex<double>* out,
118123
const int ik,
119124
const bool add,
120125
const double factor) const;
121126

122-
template void PW_Basis_K::recip_to_real<double, base_device::DEVICE_GPU>(const std::complex<double>* in,
127+
template void PW_Basis_K::recip_to_real<double, base_device::DEVICE_GPU>(const base_device::DEVICE_GPU* ctx,
128+
const std::complex<double>* in,
123129
std::complex<double>* out,
124130
const int ik,
125131
const bool add,

source/module_hamilt_general/module_xc/xc_functional_gradcorr.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -644,7 +644,8 @@ void XC_Functional::grad_wfc(
644644
rhog, porter.data<T>()); // Array of std::complex<double>
645645

646646
// bring the gdr from G --> R
647-
wfc_basis->recip_to_real<T,Device>(porter.data<T>(), porter.data<T>(), ik);
647+
Device * ctx = nullptr;
648+
wfc_basis->recip_to_real(ctx, porter.data<T>(), porter.data<T>(), ik);
648649

649650
xc_functional_grad_wfc_solver(
650651
ipol, wfc_basis->nrxx, // Integers

source/module_hamilt_pw/hamilt_ofdft/ml_data_descriptor.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ void ML_data::getF_KS1(
197197
continue;
198198
}
199199

200-
pw_psi->recip_to_real<T,Device>( &psi->operator()(ibnd,0), wfcr, ik);
200+
pw_psi->recip_to_real(ctx, &psi->operator()(ibnd,0), wfcr, ik);
201201
const double w1 = pelec->wg(ik, ibnd) / ucell.omega;
202202

203203
// output one wf, to check KS equation
@@ -308,7 +308,7 @@ void ML_data::getF_KS2(
308308
continue;
309309
}
310310

311-
pw_psi->recip_to_real<T,Device>( &psi->operator()(ibnd,0), wfcr, ik);
311+
pw_psi->recip_to_real(ctx, &psi->operator()(ibnd,0), wfcr, ik);
312312
const double w1 = pelec->wg(ik, ibnd) / ucell.omega;
313313

314314
if (pelec->ekb(ik,ibnd) > epsilonM)

source/module_hamilt_pw/hamilt_pwdft/operator_pw/meta_pw.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,13 +67,13 @@ void Meta<OperatorPW<T, Device>>::act(
6767
for (int j = 0; j < 3; j++)
6868
{
6969
meta_op()(this->ctx, this->ik, j, ngk_ik, this->wfcpw->npwk_max, this->tpiba, wfcpw->get_gcar_data<Real>(), wfcpw->get_kvec_c_data<Real>(), tmpsi_in, this->porter);
70-
wfcpw->recip_to_real<T,Device>(this->porter, this->porter, this->ik);
70+
wfcpw->recip_to_real(this->ctx, this->porter, this->porter, this->ik);
7171

7272
if(this->vk_col != 0) {
7373
vector_mul_vector_op()(this->vk_col, this->porter, this->porter, this->vk + current_spin * this->vk_col);
7474
}
7575

76-
wfcpw->real_to_recip<T,Device>(this->porter, this->porter, this->ik);
76+
wfcpw->real_to_recip(this->ctx, this->porter, this->porter, this->ik);
7777
meta_op()(this->ctx, this->ik, j, ngk_ik, this->wfcpw->npwk_max, this->tpiba, wfcpw->get_gcar_data<Real>(), wfcpw->get_kvec_c_data<Real>(), this->porter, tmhpsi, true);
7878

7979
} // x,y,z directions

source/module_hamilt_pw/hamilt_pwdft/operator_pw/op_exx_pw.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ void OperatorEXXPW<T, Device>::act_op(const int nbands,
188188
{
189189
const T *psi_nk = tmpsi_in + n_iband * nbasis;
190190
// retrieve \psi_nk in real space
191-
wfcpw->recip_to_real<T,Device>( psi_nk, psi_nk_real, this->ik);
191+
wfcpw->recip_to_real(ctx, psi_nk, psi_nk_real, this->ik);
192192

193193
// for \psi_nk, get the pw of iq and band m
194194
auto q_points = get_q_points(this->ik);
@@ -208,7 +208,7 @@ void OperatorEXXPW<T, Device>::act_op(const int nbands,
208208
// if (has_real.find({iq, m_iband}) == has_real.end())
209209
// {
210210
const T* psi_mq = get_pw(m_iband, iq);
211-
wfcpw->recip_to_real<T,Device>( psi_mq, psi_mq_real, iq);
211+
wfcpw->recip_to_real(ctx, psi_mq, psi_mq_real, iq);
212212
// syncmem_complex_op()(this->ctx, this->ctx, psi_all_real + m_iband * wfcpw->nrxx, psi_mq_real, wfcpw->nrxx);
213213
// has_real[{iq, m_iband}] = true;
214214
// }
@@ -271,7 +271,7 @@ void OperatorEXXPW<T, Device>::act_op(const int nbands,
271271
} // end of iq
272272
auto h_psi_nk = tmhpsi + n_iband * nbasis;
273273
Real hybrid_alpha = GlobalC::exx_info.info_global.hybrid_alpha;
274-
wfcpw->real_to_recip<T,Device>(h_psi_real, h_psi_nk, this->ik, true, hybrid_alpha);
274+
wfcpw->real_to_recip(ctx, h_psi_real, h_psi_nk, this->ik, true, hybrid_alpha);
275275
setmem_complex_op()(h_psi_real, 0, rhopw->nrxx);
276276

277277
}
@@ -810,7 +810,7 @@ double OperatorEXXPW<T, Device>::cal_exx_energy_op(psi::Psi<T, Device> *ppsi_) c
810810
psi.fix_kb(ik, n_iband);
811811
const T* psi_nk = psi.get_pointer();
812812
// retrieve \psi_nk in real space
813-
wfcpw->recip_to_real<T,Device>( psi_nk, psi_nk_real, ik);
813+
wfcpw->recip_to_real(ctx, psi_nk, psi_nk_real, ik);
814814

815815
// for \psi_nk, get the pw of iq and band m
816816
// q_points is a vector of integers, 0 to nks-1
@@ -839,7 +839,7 @@ double OperatorEXXPW<T, Device>::cal_exx_energy_op(psi::Psi<T, Device> *ppsi_) c
839839
psi_.fix_kb(iq, m_iband);
840840
const T* psi_mq = psi_.get_pointer();
841841
// const T* psi_mq = get_pw(m_iband, iq);
842-
wfcpw->recip_to_real<T,Device>(psi_mq, psi_mq_real, iq);
842+
wfcpw->recip_to_real(ctx, psi_mq, psi_mq_real, iq);
843843

844844
T omega_inv = 1.0 / ucell->omega;
845845

source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -642,7 +642,7 @@ void Stochastic_Iter<T, Device>::cal_storho(const UnitCell& ucell,
642642
T* tmpout = stowf.shchi->get_pointer();
643643
for (int ichi = 0; ichi < nchip_ik; ++ichi)
644644
{
645-
wfc_basis->recip_to_real<T,Device>(tmpout, porter, ik);
645+
wfc_basis->recip_to_real(this->ctx, tmpout, porter, ik);
646646
const auto w1 = static_cast<Real>(this->pkv->wk[ik]);
647647
elecstate::elecstate_pw_op<Real, Device>()(this->ctx, current_spin, nrxx, w1, pes->rho, porter);
648648
// for (int ir = 0; ir < nrxx; ++ir)

source/module_io/get_pchg_pw.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ void get_pchg_pw(const std::vector<int>& bands_to_print,
9696
<< ik % (nks / nspin) + 1 << ", spin " << spin_index + 1 << std::endl;
9797

9898
psi->fix_k(ik);
99-
pw_wfc->recip_to_real<std::complex<double>,Device>(&psi[0](ib, 0), wfcr.data(), ik);
99+
pw_wfc->recip_to_real(ctx, &psi[0](ib, 0), wfcr.data(), ik);
100100

101101
// To ensure the normalization of charge density in multi-k calculation (if if_separate_k is true)
102102
double wg_sum_k = 0;
@@ -139,7 +139,7 @@ void get_pchg_pw(const std::vector<int>& bands_to_print,
139139
<< ik % (nks / nspin) + 1 << ", spin " << spin_index + 1 << std::endl;
140140

141141
psi->fix_k(ik);
142-
pw_wfc->recip_to_real<std::complex<double>,Device>( &psi[0](ib, 0), wfcr.data(), ik);
142+
pw_wfc->recip_to_real(ctx, &psi[0](ib, 0), wfcr.data(), ik);
143143

144144
double w1 = static_cast<double>(wk[ik] / ucell->omega);
145145

0 commit comments

Comments
 (0)