88#include " module_basis/module_pw/pw_basis_k.h"
99#include " module_cell/unitcell.h"
1010#include " module_elecstate/elecstate.h"
11+ #include " module_elecstate/module_charge/charge.h"
1112#include " module_elecstate/module_charge/symmetry_rho.h"
1213#include " module_hamilt_pw/hamilt_pwdft/parallel_grid.h"
1314#include " module_psi/psi.h"
@@ -25,9 +26,6 @@ void get_pchg_pw(const std::vector<int>& out_pchg,
2526 const int ny,
2627 const int nz,
2728 const int nxyz,
28- const int nks,
29- const std::vector<int >& isk,
30- const std::vector<double >& wk,
3129 const int ngmc,
3230 UnitCell* ucell,
3331 const psi::Psi<std::complex <double >>* psi,
@@ -36,165 +34,183 @@ void get_pchg_pw(const std::vector<int>& out_pchg,
3634 const Device* ctx,
3735 const Parallel_Grid& pgrid,
3836 const std::string& global_out_dir,
39- const bool if_separate_k)
37+ const bool if_separate_k,
38+ const K_Vectors& kv,
39+ const int kpar,
40+ const int my_pool,
41+ const Charge* chg) // Charge class is needed for the charge density reduce
4042{
41- // bands_picked is a vector of 0s and 1s, where 1 means the band is picked to output
42- std::vector<int > bands_picked (nbands, 0 );
43+ // Get necessary parameters from kv
44+ const int nks = kv.get_nks (); // current process pool k-point count
45+ const int nkstot = kv.get_nkstot (); // total k-point count
4346
44- // Check if length of out_pchg is valid
45- if ( static_cast < int >(out_pchg. size ()) > nbands )
47+ // Loop over k-parallelism
48+ for ( int ip = 0 ; ip < kpar; ++ip )
4649 {
47- ModuleBase::WARNING_QUIT (" ModuleIO::get_pchg_pw" ,
48- " The number of bands specified by `out_pchg` in the "
49- " INPUT file exceeds `nbands`!" );
50- }
51-
52- // Check if all elements in bands_picked are 0 or 1
53- for (int value: out_pchg)
54- {
55- if (value != 0 && value != 1 )
50+ if (my_pool != ip)
5651 {
57- ModuleBase::WARNING_QUIT (" ModuleIO::get_pchg_pw" ,
58- " The elements of `out_pchg` must be either 0 or 1. "
59- " Invalid values found!" );
52+ continue ;
6053 }
61- }
6254
63- // Fill bands_picked with values from out_pchg
64- // Remaining bands are already set to 0
65- int length = std::min (static_cast <int >(out_pchg.size ()), nbands);
66- for (int i = 0 ; i < length; ++i)
67- {
68- // out_pchg rely on function parse_expression
69- bands_picked[i] = static_cast <int >(out_pchg[i]);
70- }
71-
72- std::vector<std::complex <double >> wfcr (nxyz);
73- std::vector<std::vector<double >> rho_band (nspin, std::vector<double >(nxyz));
55+ // bands_picked is a vector of 0s and 1s, where 1 means the band is picked to output
56+ std::vector<int > bands_picked (nbands, 0 );
7457
75- for (int ib = 0 ; ib < nbands; ++ib)
76- {
77- // Skip the loop iteration if bands_picked[ib] is 0
78- if (!bands_picked[ib])
58+ // Check if length of out_pchg is valid
59+ if (static_cast <int >(out_pchg.size ()) > nbands)
7960 {
80- continue ;
61+ ModuleBase::WARNING_QUIT (" ModuleIO::get_pchg_pw" ,
62+ " The number of bands specified by `out_pchg` in the "
63+ " INPUT file exceeds `nbands`!" );
8164 }
8265
83- for (int is = 0 ; is < nspin; ++is)
66+ // Check if all elements in bands_picked are 0 or 1
67+ for (int value: out_pchg)
8468 {
85- std::fill (rho_band[is].begin (), rho_band[is].end (), 0.0 );
69+ if (value != 0 && value != 1 )
70+ {
71+ ModuleBase::WARNING_QUIT (" ModuleIO::get_pchg_pw" ,
72+ " The elements of `out_pchg` must be either 0 or 1. "
73+ " Invalid values found!" );
74+ }
8675 }
8776
88- if (if_separate_k)
77+ // Fill bands_picked with values from out_pchg
78+ // Remaining bands are already set to 0
79+ int length = std::min (static_cast <int >(out_pchg.size ()), nbands);
80+ for (int i = 0 ; i < length; ++i)
8981 {
90- for (int ik = 0 ; ik < nks; ++ik)
91- {
92- const int spin_index = isk[ik];
93- std::cout << " Calculating band-decomposed charge density for band " << ib + 1 << " , k-point "
94- << ik % (nks / nspin) + 1 << " , spin " << spin_index + 1 << std::endl;
82+ // out_pchg rely on function parse_expression
83+ bands_picked[i] = static_cast <int >(out_pchg[i]);
84+ }
9585
96- psi-> fix_k (ik );
97- pw_wfc-> recip_to_real (ctx, &psi[ 0 ](ib, 0 ), wfcr. data (), ik );
86+ std::vector<std:: complex < double >> wfcr (nxyz );
87+ std::vector<std::vector< double >> rho_band (nspin, std::vector< double >(nxyz) );
9888
99- // To ensure the normalization of charge density in multi-k calculation (if if_separate_k is true)
100- double wg_sum_k = 0 ;
101- for (int ik_tmp = 0 ; ik_tmp < nks / nspin; ++ik_tmp)
102- {
103- wg_sum_k += wk[ik_tmp];
104- }
89+ for (int ib = 0 ; ib < nbands; ++ib)
90+ {
91+ // Skip the loop iteration if bands_picked[ib] is 0
92+ if (!bands_picked[ib])
93+ {
94+ continue ;
95+ }
10596
106- double w1 = static_cast <double >(wg_sum_k / ucell->omega );
97+ for (int is = 0 ; is < nspin; ++is)
98+ {
99+ std::fill (rho_band[is].begin (), rho_band[is].end (), 0.0 );
100+ }
107101
108- for (int i = 0 ; i < nxyz; ++i)
102+ if (if_separate_k)
103+ {
104+ for (int ik = 0 ; ik < nks; ++ik)
109105 {
110- rho_band[spin_index][i] = std::norm (wfcr[i]) * w1;
106+ const int ikstot = kv.ik2iktot [ik]; // global k-point index
107+ const int spin_index = kv.isk [ik]; // spin index
108+ const int k_number = ikstot % (nkstot / nspin) + 1 ; // k-point number, starting from 1
109+
110+ psi->fix_k (ik);
111+ pw_wfc->recip_to_real (ctx, &psi[0 ](ib, 0 ), wfcr.data (), ik);
112+
113+ // To ensure the normalization of charge density in multi-k calculation (if if_separate_k is true)
114+ double wg_sum_k = 0.0 ;
115+ if (nspin == 1 )
116+ {
117+ wg_sum_k = 2.0 ;
118+ }
119+ else if (nspin == 2 )
120+ {
121+ wg_sum_k = 1.0 ;
122+ }
123+ else
124+ {
125+ ModuleBase::WARNING_QUIT (" ModuleIO::get_pchg_pw" ,
126+ " Real space partial charge output currently do not support "
127+ " noncollinear polarized calculation (nspin = 4)!" );
128+ }
129+
130+ double w1 = static_cast <double >(wg_sum_k / ucell->omega );
131+
132+ for (int i = 0 ; i < nxyz; ++i)
133+ {
134+ rho_band[spin_index][i] = std::norm (wfcr[i]) * w1;
135+ }
136+
137+ std::stringstream ssc;
138+ ssc << global_out_dir << " BAND" << ib + 1 << " _K" << k_number << " _SPIN" << spin_index + 1
139+ << " _CHG.cube" ;
140+
141+ ModuleIO::write_vdata_palgrid (pgrid,
142+ rho_band[spin_index].data (),
143+ spin_index,
144+ nspin,
145+ 0 ,
146+ ssc.str (),
147+ 0.0 ,
148+ ucell,
149+ 11 ,
150+ 1 ,
151+ true ); // reduce_all_pool is true
111152 }
112-
113- std::cout << " Writing cube files..." ;
114-
115- std::stringstream ssc;
116- ssc << global_out_dir << " BAND" << ib + 1 << " _K" << ik % (nks / nspin) + 1 << " _SPIN" << spin_index + 1
117- << " _CHG.cube" ;
118-
119- ModuleIO::write_vdata_palgrid (pgrid,
120- rho_band[spin_index].data (),
121- spin_index,
122- nspin,
123- 0 ,
124- ssc.str (),
125- 0.0 ,
126- ucell);
127-
128- std::cout << " Complete!" << std::endl;
129153 }
130- }
131- else
132- {
133- for (int ik = 0 ; ik < nks; ++ik)
154+ else
134155 {
135- const int spin_index = isk[ik];
136- std::cout << " Calculating band-decomposed charge density for band " << ib + 1 << " , k-point "
137- << ik % (nks / nspin) + 1 << " , spin " << spin_index + 1 << std::endl;
156+ for (int ik = 0 ; ik < nks; ++ik)
157+ {
158+ const int ikstot = kv.ik2iktot [ik]; // global k-point index
159+ const int spin_index = kv.isk [ik]; // spin index
160+ const int k_number = ikstot % (nkstot / nspin) + 1 ; // k-point number, starting from 1
138161
139- psi->fix_k (ik);
140- pw_wfc->recip_to_real (ctx, &psi[0 ](ib, 0 ), wfcr.data (), ik);
162+ psi->fix_k (ik);
163+ pw_wfc->recip_to_real (ctx, &psi[0 ](ib, 0 ), wfcr.data (), ik);
141164
142- double w1 = static_cast <double >(wk[ik] / ucell->omega );
165+ double w1 = static_cast <double >(kv. wk [ik] / ucell->omega );
143166
144- for (int i = 0 ; i < nxyz; ++i)
145- {
146- rho_band[spin_index][i] += std::norm (wfcr[i]) * w1;
167+ for (int i = 0 ; i < nxyz; ++i)
168+ {
169+ rho_band[spin_index][i] += std::norm (wfcr[i]) * w1;
170+ }
147171 }
148- }
149172
150- // Symmetrize the charge density, otherwise the results are incorrect if the symmetry is on
151- std::cout << " Symmetrizing band-decomposed charge density..." << std::endl;
152- Symmetry_rho srho;
153- for (int is = 0 ; is < nspin; ++is)
154- {
155- // Use vector instead of raw pointers
156- std::vector<double *> rho_save_pointers (nspin);
157- for (int s = 0 ; s < nspin; ++s)
173+ // Reduce the charge density across all pools if kpar > 1
174+ if (kpar > 1 && chg != nullptr )
158175 {
159- rho_save_pointers[s] = rho_band[s].data ();
176+ for (int is = 0 ; is < nspin; ++is)
177+ {
178+ chg->reduce_diff_pools (rho_band[is].data ());
179+ }
160180 }
161181
162- std::vector<std::vector<std::complex <double >>> rhog (nspin, std::vector<std::complex <double >>(ngmc));
163-
164- // Convert vector of vectors to vector of pointers
165- std::vector<std::complex <double >*> rhog_pointers (nspin);
166- for (int s = 0 ; s < nspin; ++s)
182+ // Symmetrize the charge density, otherwise the results are incorrect if the symmetry is on
183+ std::cout << " Symmetrizing band-decomposed charge density..." << std::endl;
184+ Symmetry_rho srho;
185+ for (int is = 0 ; is < nspin; ++is)
167186 {
168- rhog_pointers[s] = rhog[s].data ();
187+ // Use vector instead of raw pointers
188+ std::vector<double *> rho_save_pointers (nspin);
189+ for (int s = 0 ; s < nspin; ++s)
190+ {
191+ rho_save_pointers[s] = rho_band[s].data ();
192+ }
193+
194+ std::vector<std::vector<std::complex <double >>> rhog (nspin, std::vector<std::complex <double >>(ngmc));
195+
196+ // Convert vector of vectors to vector of pointers
197+ std::vector<std::complex <double >*> rhog_pointers (nspin);
198+ for (int s = 0 ; s < nspin; ++s)
199+ {
200+ rhog_pointers[s] = rhog[s].data ();
201+ }
202+
203+ srho.begin (is, rho_save_pointers.data (), rhog_pointers.data (), ngmc, nullptr , pw_rhod, ucell->symm );
169204 }
170205
171- srho.begin (is,
172- rho_save_pointers.data (),
173- rhog_pointers.data (),
174- ngmc,
175- nullptr ,
176- pw_rhod,
177- ucell->symm );
178- }
179-
180- std::cout << " Writing cube files..." ;
206+ for (int is = 0 ; is < nspin; ++is)
207+ {
208+ std::stringstream ssc;
209+ ssc << global_out_dir << " BAND" << ib + 1 << " _SPIN" << is + 1 << " _CHG.cube" ;
181210
182- for (int is = 0 ; is < nspin; ++is)
183- {
184- std::stringstream ssc;
185- ssc << global_out_dir << " BAND" << ib + 1 << " _SPIN" << is + 1 << " _CHG.cube" ;
186-
187- ModuleIO::write_vdata_palgrid (pgrid,
188- rho_band[is].data (),
189- is,
190- nspin,
191- 0 ,
192- ssc.str (),
193- 0.0 ,
194- ucell);
211+ ModuleIO::write_vdata_palgrid (pgrid, rho_band[is].data (), is, nspin, 0 , ssc.str (), 0.0 , ucell);
212+ }
195213 }
196-
197- std::cout << " Complete!" << std::endl;
198214 }
199215 }
200216}
0 commit comments