22
33#include " module_base/macros.h"
44#include " module_base/memory.h"
5+ #include " module_base/parallel_device.h"
56#include " module_base/timer.h"
67#include " module_base/tool_quit.h"
78#include " module_hsolver/diago_iter_assist.h"
@@ -40,8 +41,8 @@ void PSIInit<T, Device>::prepare_init(const int& random_seed)
4041 // use new instead, but will cause asymmetric allocation and deallocation, in literal aspect
4142 ModuleBase::timer::tick (" PSIInit" , " prepare_init" );
4243 this ->psi_initer .reset ();
43- if (this ->init_wfc == " random" || (PARAM. inp . ks_solver == " bpcg " && PARAM. inp . bndpar > 1 ) )
44- { // temporary solution for band parallel bpcg
44+ if (this ->init_wfc == " random" )
45+ {
4546 this ->psi_initer = std::unique_ptr<psi_initializer<T>>(new psi_initializer_random<T>());
4647 }
4748 else if (this ->init_wfc == " file" )
@@ -97,30 +98,34 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
9798 ModuleBase::timer::tick (" PSIInit" , " initialize_psi" );
9899
99100 const int nbands_start = this ->psi_initer ->nbands_start ();
100- const int nbands = psi->get_nbands ();
101+ const int nbands_l = psi->get_nbands ();
101102 const int nbasis = psi->get_nbasis ();
102- const bool not_equal = (nbands_start != nbands );
103+ const bool not_equal = (nbands_start != nbands_l );
103104
104105 Psi<T>* psi_cpu = reinterpret_cast <psi::Psi<T>*>(psi);
105106 Psi<T, Device>* psi_device = kspw_psi;
106107
107- if (not_equal)
108- {
109- psi_cpu = new Psi<T>(1 , nbands_start, nbasis, nbasis, true );
110- psi_device = PARAM.inp .device == " gpu" ? new psi::Psi<T, Device>(psi_cpu[0 ])
111- : reinterpret_cast <psi::Psi<T, Device>*>(psi_cpu);
112- }
113- else if (PARAM.inp .precision == " single" )
108+ bool fill = PARAM.inp .ks_solver != " bpcg" || GlobalV::MY_BNDGROUP == 0 ;
109+ if (fill)
114110 {
115- if (PARAM. inp . device == " cpu " )
111+ if (not_equal )
116112 {
117- psi_cpu = reinterpret_cast <psi::Psi<T>*>(kspw_psi);
118- psi_device = kspw_psi;
113+ psi_cpu = new Psi<T>(1 , nbands_start, nbasis, nbasis, true );
114+ psi_device = PARAM.inp .device == " gpu" ? new psi::Psi<T, Device>(psi_cpu[0 ])
115+ : reinterpret_cast <psi::Psi<T, Device>*>(psi_cpu);
119116 }
120- else
117+ else if (PARAM. inp . precision == " single " )
121118 {
122- psi_cpu = new Psi<T>(1 , nbands_start, nbasis, nbasis, true );
123- psi_device = kspw_psi;
119+ if (PARAM.inp .device == " cpu" )
120+ {
121+ psi_cpu = reinterpret_cast <psi::Psi<T>*>(kspw_psi);
122+ psi_device = kspw_psi;
123+ }
124+ else
125+ {
126+ psi_cpu = new Psi<T>(1 , nbands_start, nbasis, nbasis, true );
127+ psi_device = kspw_psi;
128+ }
124129 }
125130 }
126131
@@ -134,58 +139,90 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
134139
135140 // ! Update Hamiltonian from other kpoint to the given one
136141 p_hamilt->updateHk (ik);
137-
138- // ! initialize psi_cpu
139- this ->psi_initer ->init_psig (psi_cpu->get_pointer (), ik);
140- if (psi_device->get_pointer () != psi_cpu->get_pointer ())
142+ if (fill)
141143 {
142- syncmem_h2d_op ()(psi_device->get_pointer (), psi_cpu->get_pointer (), nbands_start * nbasis);
143- }
144-
145- std::vector<typename GetTypeReal<T>::type> etatom (nbands_start, 0.0 );
144+ // ! initialize psi_cpu
145+ this ->psi_initer ->init_psig (psi_cpu->get_pointer (), ik);
146+ if (psi_device->get_pointer () != psi_cpu->get_pointer ())
147+ {
148+ syncmem_h2d_op ()(psi_device->get_pointer (), psi_cpu->get_pointer (), nbands_start * nbasis);
149+ }
146150
147- if (this ->ks_solver == " cg" )
148- {
149- if (not_equal)
151+ if (this ->ks_solver == " cg" )
150152 {
151- // for diagH_subspace_init, psi_device->get_pointer() and kspw_psi->get_pointer() should be different
152- hsolver::DiagoIterAssist<T, Device>::diagH_subspace_init (p_hamilt,
153- psi_device->get_pointer (),
154- nbands_start,
155- nbasis,
156- *(kspw_psi),
157- etatom.data ());
153+ std::vector<typename GetTypeReal<T>::type> etatom (nbands_start, 0.0 );
154+ if (not_equal)
155+ {
156+ // for diagH_subspace_init, psi_device->get_pointer() and kspw_psi->get_pointer() should be
157+ // different
158+ hsolver::DiagoIterAssist<T, Device>::diagH_subspace_init (p_hamilt,
159+ psi_device->get_pointer (),
160+ nbands_start,
161+ nbasis,
162+ *(kspw_psi),
163+ etatom.data ());
164+ }
165+ else
166+ {
167+ // for diagH_subspace, psi_device->get_pointer() and kspw_psi->get_pointer() can be the same
168+ hsolver::DiagoIterAssist<T, Device>::diagH_subspace (p_hamilt,
169+ *psi_device,
170+ *kspw_psi,
171+ etatom.data (),
172+ nbands_start);
173+ }
158174 }
159- else
175+ else // dav, bpcg
160176 {
161- // for diagH_subspace, psi_device->get_pointer() and kspw_psi->get_pointer() can be the same
162- hsolver::DiagoIterAssist<T, Device>::diagH_subspace (p_hamilt,
163- *psi_device,
164- *kspw_psi,
165- etatom.data (),
166- nbands_start);
177+ if (psi_device->get_pointer () != kspw_psi->get_pointer ())
178+ {
179+ syncmem_complex_op ()(kspw_psi->get_pointer (), psi_device->get_pointer (), nbands_l * nbasis);
180+ }
167181 }
168182 }
169- else // dav, bpcg
183+ #ifdef __MPI
184+ if (PARAM.inp .ks_solver == " bpcg" && PARAM.inp .bndpar > 1 )
170185 {
171- if (psi_device->get_pointer () != kspw_psi->get_pointer ())
186+ std::vector<int > sendcounts (PARAM.inp .bndpar );
187+ std::vector<int > displs (PARAM.inp .bndpar );
188+ MPI_Allgather (&nbands_l, 1 , MPI_INT, sendcounts.data (), 1 , MPI_INT, BP_WORLD);
189+ displs[0 ] = 0 ;
190+ sendcounts[0 ] *= nbasis;
191+ for (int i = 1 ; i < PARAM.inp .bndpar ; i++)
172192 {
173- syncmem_complex_op ()(kspw_psi->get_pointer (), psi_device->get_pointer (), nbands * nbasis);
193+ sendcounts[i] *= nbasis;
194+ displs[i] = displs[i - 1 ] + sendcounts[i - 1 ];
174195 }
196+ if (GlobalV::MY_BNDGROUP == 0 )
197+ {
198+ for (int ip = 1 ; ip < PARAM.inp .bndpar ; ++ip)
199+ {
200+ Parallel_Common::send_data (psi_cpu->get_pointer () + displs[ip], sendcounts[ip], ip, 0 , BP_WORLD);
201+ }
202+ }
203+ else
204+ {
205+ MPI_Status status;
206+ Parallel_Common::recv_dev<T, Device>(kspw_psi->get_pointer (), nbands_l * nbasis, 0 , 0 , BP_WORLD, &status);
207+ }
175208 }
209+ #endif
176210 } // end k-point loop
177211
178- if (not_equal )
212+ if (fill )
179213 {
180- delete psi_cpu;
181- if (PARAM.inp .device == " gpu" )
214+ if (not_equal)
182215 {
183- delete psi_device;
216+ delete psi_cpu;
217+ if (PARAM.inp .device == " gpu" )
218+ {
219+ delete psi_device;
220+ }
221+ }
222+ else if (PARAM.inp .precision == " single" && PARAM.inp .device == " gpu" )
223+ {
224+ delete psi_cpu;
184225 }
185- }
186- else if (PARAM.inp .precision == " single" && PARAM.inp .device == " gpu" )
187- {
188- delete psi_cpu;
189226 }
190227
191228 ModuleBase::timer::tick (" PSIInit" , " initialize_psi" );
@@ -203,7 +240,11 @@ void PSIInit<T, Device>::initialize_lcao_in_pw(Psi<T>* psi_local, std::ofstream&
203240 }
204241}
205242
206- void allocate_psi (Psi<std::complex <double >>*& psi, const int & nks, const std::vector<int >& ngk, const int & nbands, const int & npwx)
243+ void allocate_psi (Psi<std::complex <double >>*& psi,
244+ const int & nks,
245+ const std::vector<int >& ngk,
246+ const int & nbands,
247+ const int & npwx)
207248{
208249 assert (npwx > 0 );
209250 assert (nks > 0 );
0 commit comments