Skip to content

Commit bec10a2

Browse files
committed
set npol to private
1 parent 2ebc9bb commit bec10a2

File tree

10 files changed

+24
-23
lines changed

10 files changed

+24
-23
lines changed

source/module_elecstate/elecstate_pw.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ void ElecStatePW<T, Device>::cal_becsum(const psi::Psi<T, Device>& psi)
271271
{
272272
const T one{1, 0};
273273
const T zero{0, 0};
274-
const int npol = psi.npol;
274+
const int npol = psi.get_npol();
275275
const int npwx = psi.get_nbasis() / npol;
276276
const int nbands = psi.get_nbands() * npol;
277277
const int nkb = this->ppcell->nkb;

source/module_hamilt_general/operator.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ typename Operator<T, Device>::hpsi_info Operator<T, Device>::hPsi(hpsi_info& inp
6363
delete this->hpsi;
6464
this->hpsi = new psi::Psi<T, Device>(hpsi_pointer,
6565
1,
66-
nbands / psi_input->npol,
66+
nbands / psi_input->get_npol(),
6767
psi_input->get_nbasis(),
6868
psi_input->get_nbasis(),
6969
true);
@@ -86,7 +86,7 @@ typename Operator<T, Device>::hpsi_info Operator<T, Device>::hPsi(hpsi_info& inp
8686
default:
8787
op->act(nbands,
8888
psi_input->get_nbasis(),
89-
psi_input->npol,
89+
psi_input->get_npol(),
9090
tmpsi_in,
9191
this->hpsi->get_pointer(),
9292
psi_input->get_current_nbas(),
@@ -105,7 +105,7 @@ typename Operator<T, Device>::hpsi_info Operator<T, Device>::hPsi(hpsi_info& inp
105105
}
106106
ModuleBase::timer::tick("Operator", "hPsi");
107107

108-
return hpsi_info(this->hpsi, psi::Range(1, 0, 0, nbands / psi_input->npol), hpsi_pointer);
108+
return hpsi_info(this->hpsi, psi::Range(1, 0, 0, nbands / psi_input->get_npol()), hpsi_pointer);
109109
}
110110

111111
template <typename T, typename Device>

source/module_hamilt_lcao/module_deltaspin/cal_mw.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ void spinconstrain::SpinConstrain<std::complex<double>>::cal_mi_pw()
6666
psi::Psi<std::complex<double>, base_device::DEVICE_CPU>* psi_t = static_cast<psi::Psi<std::complex<double>, base_device::DEVICE_CPU>*>(this->psi);
6767
const int nbands = psi_t->get_nbands();
6868
const int nks = psi_t->get_nk();
69-
const int npol = psi_t->npol;
69+
const int npol = psi_t->get_npol();
7070
for(int ik = 0; ik < nks; ik++)
7171
{
7272
psi_t->fix_k(ik);
@@ -112,7 +112,7 @@ void spinconstrain::SpinConstrain<std::complex<double>>::cal_mi_pw()
112112
psi::Psi<std::complex<double>, base_device::DEVICE_GPU>* psi_t = static_cast<psi::Psi<std::complex<double>, base_device::DEVICE_GPU>*>(this->psi);
113113
const int nbands = psi_t->get_nbands();
114114
const int nks = psi_t->get_nk();
115-
const int npol = psi_t->npol;
115+
const int npol = psi_t->get_npol();
116116
for(int ik = 0; ik < nks; ik++)
117117
{
118118
psi_t->fix_k(ik);

source/module_hamilt_lcao/module_deltaspin/cal_mw_from_lambda.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ void spinconstrain::SpinConstrain<std::complex<double>>::cal_mw_from_lambda(int
199199
hamilt::Hamilt<std::complex<double>, base_device::DEVICE_CPU>* hamilt_t = static_cast<hamilt::Hamilt<std::complex<double>, base_device::DEVICE_CPU>*>(this->p_hamilt);
200200
auto* onsite_p = projectors::OnsiteProjector<double, base_device::DEVICE_CPU>::get_instance();
201201
nbands = psi_t->get_nbands();
202-
npol = psi_t->npol;
202+
npol = psi_t->get_npol();
203203
nkb = onsite_p->get_tot_nproj();
204204
nk = psi_t->get_nk();
205205
nh_iat = &onsite_p->get_nh(0);
@@ -252,7 +252,7 @@ void spinconstrain::SpinConstrain<std::complex<double>>::cal_mw_from_lambda(int
252252
hamilt::Hamilt<std::complex<double>, base_device::DEVICE_GPU>* hamilt_t = static_cast<hamilt::Hamilt<std::complex<double>, base_device::DEVICE_GPU>*>(this->p_hamilt);
253253
auto* onsite_p = projectors::OnsiteProjector<double, base_device::DEVICE_GPU>::get_instance();
254254
nbands = psi_t->get_nbands();
255-
npol = psi_t->npol;
255+
npol = psi_t->get_npol();
256256
nkb = onsite_p->get_tot_nproj();
257257
nk = psi_t->get_nk();
258258
nh_iat = &onsite_p->get_nh(0);
@@ -382,7 +382,7 @@ void spinconstrain::SpinConstrain<std::complex<double>>::update_psi_charge(const
382382
hamilt::Hamilt<std::complex<double>, base_device::DEVICE_CPU>* hamilt_t = static_cast<hamilt::Hamilt<std::complex<double>, base_device::DEVICE_CPU>*>(this->p_hamilt);
383383
auto* onsite_p = projectors::OnsiteProjector<double, base_device::DEVICE_CPU>::get_instance();
384384
nbands = psi_t->get_nbands();
385-
npol = psi_t->npol;
385+
npol = psi_t->get_npol();
386386
nkb = onsite_p->get_tot_nproj();
387387
nk = psi_t->get_nk();
388388
nh_iat = &onsite_p->get_nh(0);
@@ -454,7 +454,7 @@ void spinconstrain::SpinConstrain<std::complex<double>>::update_psi_charge(const
454454
hamilt::Hamilt<std::complex<double>, base_device::DEVICE_GPU>* hamilt_t = static_cast<hamilt::Hamilt<std::complex<double>, base_device::DEVICE_GPU>*>(this->p_hamilt);
455455
auto* onsite_p = projectors::OnsiteProjector<double, base_device::DEVICE_GPU>::get_instance();
456456
nbands = psi_t->get_nbands();
457-
npol = psi_t->npol;
457+
npol = psi_t->get_npol();
458458
nkb = onsite_p->get_tot_nproj();
459459
nk = psi_t->get_nk();
460460
nh_iat = &onsite_p->get_nh(0);

source/module_hamilt_lcao/module_dftu/dftu_pw.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,11 @@ void DFTU::cal_occ_pw(const int iter, const void* psi_in, const ModuleBase::matr
2929
psi_p->fix_k(ik);
3030
onsite_p->tabulate_atomic(ik);
3131

32-
onsite_p->overlap_proj_psi(nbands*psi_p->npol, psi_p->get_pointer());
32+
onsite_p->overlap_proj_psi(nbands * psi_p->get_npol(), psi_p->get_pointer());
3333
const std::complex<double>* becp = onsite_p->get_h_becp();
3434
// becp(nbands*npol , nkb)
3535
// mag = wg * \sum_{nh}becp * becp
36-
int nkb = onsite_p->get_size_becp() / nbands / psi_p->npol;
36+
int nkb = onsite_p->get_size_becp() / nbands / psi_p->get_npol();
3737
int begin_ih = 0;
3838
for(int iat = 0; iat < cell.nat; iat++)
3939
{
@@ -88,11 +88,11 @@ void DFTU::cal_occ_pw(const int iter, const void* psi_in, const ModuleBase::matr
8888
psi_p->fix_k(ik);
8989
onsite_p->tabulate_atomic(ik);
9090

91-
onsite_p->overlap_proj_psi(nbands*psi_p->npol, psi_p->get_pointer());
91+
onsite_p->overlap_proj_psi(nbands*psi_p->get_npol(), psi_p->get_pointer());
9292
const std::complex<double>* becp = onsite_p->get_h_becp();
9393
// becp(nbands*npol , nkb)
9494
// mag = wg * \sum_{nh}becp * becp
95-
int nkb = onsite_p->get_size_becp() / nbands / psi_p->npol;
95+
int nkb = onsite_p->get_size_becp() / nbands / psi_p->get_npol();
9696
int begin_ih = 0;
9797
for(int iat = 0; iat < cell.nat; iat++)
9898
{

source/module_hamilt_pw/hamilt_pwdft/onsite_projector.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ void projectors::OnsiteProjector<T, Device>::init(const std::string& orbital_dir
165165
RadialProjection::RadialProjector::_build_backward_map(it2iproj, lproj, irow2it_, irow2iproj_, irow2m_);
166166
RadialProjection::RadialProjector::_build_forward_map(it2ia, it2iproj, lproj, itiaiprojm2irow_);
167167
//rp_._build_sbt_tab(rgrid, projs, lproj, nq, dq);
168-
rp_._build_sbt_tab(nproj, rgrid, projs, lproj, nq, dq, ucell_in->omega, psi.npol, tab, nhtol);
168+
rp_._build_sbt_tab(nproj, rgrid, projs, lproj, nq, dq, ucell_in->omega, psi.get_npol(), tab, nhtol);
169169
// For being compatible with present cal_force and cal_stress framework
170170
// uncomment the following code block if you want to use the Onsite_Proj_tools
171171
if(this->tab_atomic_ == nullptr)
@@ -541,7 +541,7 @@ void projectors::OnsiteProjector<T, Device>::cal_occupations(const psi::Psi<std:
541541
}
542542
// std::cout << __FILE__ << ":" << __LINE__ << " nbands = " << nbands << std::endl;
543543
this->overlap_proj_psi(
544-
nbands * psi_in->npol,
544+
nbands * psi_in->get_npol(),
545545
psi_in->get_pointer());
546546
const std::complex<double>* becp_p = this->get_h_becp();
547547
// becp(nbands*npol , nkb)

source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ void Velocity::act
5151

5252
const int npw = psi_in->get_current_nbas();
5353

54-
const int max_npw = psi_in->get_nbasis() / psi_in->npol;
55-
const int npol = psi_in->npol;
54+
const int max_npw = psi_in->get_nbasis() / psi_in->get_npol();
55+
const int npol = psi_in->get_npol();
5656
const std::complex<double>* tmpsi_in = psi0;
5757
std::complex<double>* tmhpsi = vpsi;
5858
// -------------

source/module_io/write_vxc_lip.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ namespace ModuleIO
161161
// psi::Psi<T> hpsi_single_band(&hpsi_localxc(ik, ib, 0), 1, 1, hpsi_localxc.get_current_nbas());
162162
// vxcs_op_pw->act(1, psi_pw.get_current_nbas(), psi_pw.npol, psi_single_band.get_pointer(), hpsi_single_band.get_pointer(), psi_pw.get_ngk(ik));
163163
// }
164-
vxcs_op_pw->act(psi_pw.get_nbands(), psi_pw.get_nbasis(), psi_pw.npol, &psi_pw(ik, 0, 0), &hpsi_localxc(ik, 0, 0), psi_pw.get_ngk(ik));
164+
vxcs_op_pw->act(psi_pw.get_nbands(), psi_pw.get_nbasis(), psi_pw.get_npol(), &psi_pw(ik, 0, 0), &hpsi_localxc(ik, 0, 0), psi_pw.get_ngk(ik));
165165
delete vxcs_op_pw;
166166
std::vector<T> vxc_local_k_mo = psi_Hpsi(&psi_pw(ik, 0, 0), &hpsi_localxc(ik, 0, 0), psi_pw.get_nbasis(), psi_pw.get_nbands());
167167
Parallel_Reduce::reduce_pool(vxc_local_k_mo.data(), nbands * nbands);

source/module_psi/psi.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ template <typename T, typename Device>
191191
Psi<T, Device>::Psi(const Psi& psi_in)
192192
{
193193
this->ngk = psi_in.ngk;
194-
this->npol = psi_in.npol;
194+
this->npol = PARAM.globalv.npol;
195195
this->nk = psi_in.get_nk();
196196
this->nbands = psi_in.get_nbands();
197197
this->nbasis = psi_in.get_nbasis();
@@ -218,7 +218,7 @@ template <typename T_in, typename Device_in>
218218
Psi<T, Device>::Psi(const Psi<T_in, Device_in>& psi_in)
219219
{
220220
this->ngk = psi_in.get_ngk_pointer();
221-
this->npol = psi_in.npol;
221+
this->npol = PARAM.globalv.npol;
222222
this->nk = psi_in.get_nk();
223223
this->nbands = psi_in.get_nbands();
224224
this->nbasis = psi_in.get_nbasis();

source/module_psi/psi.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,15 +134,16 @@ class Psi
134134

135135
const int& get_current_ngk() const;
136136

137+
const int& get_npol() const {return this->npol;}
138+
137139
// solve Range: return(pointer of begin, number of bands or k-points)
138140
std::tuple<const T*, int> to_range(const Range& range) const;
139141

140-
int npol = 1;
141-
142142
private:
143143
T* psi = nullptr; // avoid using C++ STL
144144

145145
Device* ctx = {}; // an context identifier for obtaining the device variable
146+
int npol = 1;
146147

147148
// dimensions
148149
int nk = 1; // number of k points

0 commit comments

Comments
 (0)