@@ -116,150 +116,161 @@ void OperatorEXXPW<T, Device>::construct_ace() const
116116 ModuleBase::timer::tick (" OperatorEXXPW" , " construct_ace" );
117117
118118 int nk_max = kv->para_k .get_max_nks_pool ();
119- for (int ik = 0 ; ik < nk_max; ik++)
119+ int nspin_fac = PARAM.inp .nspin == 2 ? 2 : 1 ;
120+ for (int ispin = 0 ; ispin < nspin_fac; ispin++)
120121 {
121- int npwk = wfcpw->npwk [ik];
122+ for (int ik0 = 0 ; ik0 < nk_max; ik0++)
123+ {
124+ int ik = ik0 + ispin * nk_max;
125+ int npwk = wfcpw->npwk [ik];
122126
123- T* Xi_ace = Xi_ace_k[ik];
124- psi.fix_kb (ik, 0 );
125- T* p_psi = psi.get_pointer ();
127+ T* Xi_ace = Xi_ace_k[ik];
128+ psi.fix_kb (ik, 0 );
129+ T* p_psi = psi.get_pointer ();
126130
127- setmem_complex_op ()(h_psi_ace, 0 , nbands * nbasis);
131+ setmem_complex_op ()(h_psi_ace, 0 , nbands * nbasis);
128132
129- setmem_complex_op ()(h_psi_recip, 0 , wfcpw->npwk_max );
130- setmem_complex_op ()(h_psi_real, 0 , rhopw_dev->nrxx );
131- setmem_complex_op ()(density_real, 0 , rhopw_dev->nrxx );
132- setmem_complex_op ()(density_recip, 0 , rhopw_dev->npw );
133- setmem_complex_op ()(psi_nk_real, 0 , wfcpw->nrxx );
134- setmem_complex_op ()(psi_mq_real, 0 , wfcpw->nrxx );
135- int nqs = kv->get_nkstot_full ();
133+ setmem_complex_op ()(h_psi_recip, 0 , wfcpw->npwk_max );
134+ setmem_complex_op ()(h_psi_real, 0 , rhopw_dev->nrxx );
135+ setmem_complex_op ()(density_real, 0 , rhopw_dev->nrxx );
136+ setmem_complex_op ()(density_recip, 0 , rhopw_dev->npw );
137+ setmem_complex_op ()(psi_nk_real, 0 , wfcpw->nrxx );
138+ setmem_complex_op ()(psi_mq_real, 0 , wfcpw->nrxx );
139+ int nqs = kv->get_nkstot_full ();
136140
137- bool skip_ik = false ;
138- if (ik >= wfcpw->nks )
139- {
140- skip_ik = true ;
141- }
142- if (skip_ik)
143- {
144- // ik fixed here, select band n
145- for (int iq = 0 ; iq < nqs; iq++)
141+ bool skip_ik = false ;
142+ if (ik >= wfcpw->nks )
143+ {
144+ skip_ik = true ;
145+ }
146+ if (skip_ik)
146147 {
147- // for \psi_nk, get the pw of iq and band m
148- get_exx_potential<Real, Device>(kv, wfcpw, rhopw_dev, pot, tpiba, gamma_extrapolation, ucell->omega , ik, iq);
148+ // ik fixed here, select band n
149+ for (int iq0 = 0 ; iq0 < nqs; iq0++)
150+ {
151+ int iq = iq0 + ik;
152+ // for \psi_nk, get the pw of iq and band m
153+ get_exx_potential<Real, Device>(kv, wfcpw, rhopw_dev, pot, tpiba, gamma_extrapolation, ucell->omega , ik, iq);
149154
150- // decide which pool does the iq belong to
151- int iq_pool = kv->para_k .whichpool [iq ];
152- int iq_loc = iq - kv->para_k .startk_pool [iq_pool];
155+ // decide which pool does the iq belong to
156+ int iq_pool = kv->para_k .whichpool [iq0 ];
157+ int iq_loc = iq - kv->para_k .startk_pool [iq_pool];
153158
154- for (int m_iband = 0 ; m_iband < psi.get_nbands (); m_iband++)
155- {
156- double wg_mqb = 0 ;
157- bool skip = false ;
158- if (iq_pool == GlobalV::MY_POOL)
159+ for (int m_iband = 0 ; m_iband < psi.get_nbands (); m_iband++)
159160 {
160- wg_mqb = (*wg)(iq_loc, m_iband);
161- }
161+ double wg_mqb = 0 ;
162+ bool skip = false ;
163+ if (iq_pool == GlobalV::MY_POOL)
164+ {
165+ wg_mqb = (*wg)(iq_loc, m_iband);
166+ }
162167
163- MPI_Bcast (&wg_mqb, 1 , MPI_DOUBLE, kv->para_k .get_startpro_pool (iq_pool), MPI_COMM_WORLD);
168+ MPI_Bcast (&wg_mqb, 1 , MPI_DOUBLE, kv->para_k .get_startpro_pool (iq_pool), MPI_COMM_WORLD);
164169
165- if (wg_mqb < 1e-12 )
166- continue ;
170+ if (wg_mqb < 1e-12 )
171+ continue ;
167172
168- if (iq_pool == GlobalV::MY_POOL)
169- {
170- const T* psi_mq = get_pw (m_iband, iq_loc);
171- wfcpw->recip_to_real (ctx, psi_mq, psi_mq_real, iq_loc);
172- // send
173- }
174- // if (iq == 0)
175- // std::cout << "Bcast psi_mq_real" << std::endl;
176- MPI_Bcast (psi_mq_real, wfcpw->nrxx , MPI_DOUBLE_COMPLEX, iq_pool, KP_WORLD);
173+ if (iq_pool == GlobalV::MY_POOL)
174+ {
175+ const T* psi_mq = get_pw (m_iband, iq_loc);
176+ wfcpw->recip_to_real (ctx, psi_mq, psi_mq_real, iq_loc);
177+ // send
178+ }
179+ // if (iq == 0)
180+ // std::cout << "Bcast psi_mq_real" << std::endl;
181+ MPI_Bcast (psi_mq_real, wfcpw->nrxx , MPI_DOUBLE_COMPLEX, iq_pool, KP_WORLD);
177182
178183
179- } // end of iq
184+ } // end of iq
180185
186+ }
181187 }
182- }
183- else
184- {
185- *ik_ = ik;
186- act_op (nbands, nbasis, 1 , p_psi, h_psi_ace, nbasis, false );
187- // psi_h_psi_ace = psi^\dagger * h_psi_ace
188- // p_exx_helper->psi.fix_kb(0, 0);
189- gemm_complex_op ()(' C' ,
190- ' N' ,
191- nbands,
192- nbands,
193- npwk,
194- &intermediate_one,
195- p_psi,
196- nbasis,
197- h_psi_ace,
198- nbasis,
199- &intermediate_zero,
200- psi_h_psi_ace,
201- nbands);
202-
203- // reduction of psi_h_psi_ace, due to distributed memory
204- Parallel_Reduce::reduce_pool (psi_h_psi_ace, nbands * nbands);
205-
206- T intermediate_minus_one = -1.0 ;
207- axpy_complex_op ()(nbands * nbands,
208- &intermediate_minus_one,
209- psi_h_psi_ace,
210- 1 ,
211- L_ace,
212- 1 );
213-
214-
215- int info = 0 ;
216- char up = ' U' , lo = ' L' ;
217- //
218- // for (int i = 0; i < nbands; ++i)
219- // {
220- // for (int j = 0; j < nbands; ++j)
221- // {
222- // {
223- // std::cout << psi_h_psi_ace[i * nbands + j] << " ";
224- // }
225- // }
226- // std::cout << std::endl;
227- // }
228- // MPI_Barrier(MPI_COMM_WORLD);
229- // MPI_Abort(MPI_COMM_WORLD, 0);
230-
231- lapack_potrf ()(lo, nbands, L_ace, nbands);
232-
233- // expand for-loop
234- for (int i = 0 ; i < nbands; ++i) {
235- setmem_complex_op ()(L_ace + i * nbands, 0 , i);
188+ else
189+ {
190+ *ik_ = ik;
191+ act_op_kpar (nbands, nbasis, 1 , p_psi, h_psi_ace, nbasis, false );
192+ // psi_h_psi_ace = psi^\dagger * h_psi_ace
193+ // p_exx_helper->psi.fix_kb(0, 0);
194+ gemm_complex_op ()(' C' ,
195+ ' N' ,
196+ nbands,
197+ nbands,
198+ npwk,
199+ &intermediate_one,
200+ p_psi,
201+ nbasis,
202+ h_psi_ace,
203+ nbasis,
204+ &intermediate_zero,
205+ psi_h_psi_ace,
206+ nbands);
207+
208+ // reduction of psi_h_psi_ace, due to distributed memory
209+ Parallel_Reduce::reduce_pool (psi_h_psi_ace, nbands * nbands);
210+
211+ T intermediate_minus_one = -1.0 ;
212+ axpy_complex_op ()(nbands * nbands,
213+ &intermediate_minus_one,
214+ psi_h_psi_ace,
215+ 1 ,
216+ L_ace,
217+ 1 );
218+
219+
220+ int info = 0 ;
221+ char up = ' U' , lo = ' L' ;
222+
223+ // for (int i = 0; i < nbands; ++i)
224+ // {
225+ // for (int j = 0; j < nbands; ++j)
226+ // {
227+ // // std::cout << L_ace[i * nbands + j]. << " ";
228+ // if (L_ace[i * nbands + j].imag() >= 0.0)
229+ // {
230+ // std::cout << L_ace[i * nbands + j].real() << "+" << L_ace[i * nbands + j].imag() << "im ";
231+ // }
232+ // else
233+ // {
234+ // std::cout << L_ace[i * nbands + j].real() << L_ace[i * nbands + j].imag() << "im ";
235+ // }
236+ // }
237+ // std::cout << ";" << std::endl;
238+ // }
239+ // MPI_Barrier(MPI_COMM_WORLD);
240+ // MPI_Abort(MPI_COMM_WORLD, 0);
241+
242+ lapack_potrf ()(lo, nbands, L_ace, nbands);
243+
244+ // expand for-loop
245+ for (int i = 0 ; i < nbands; ++i) {
246+ setmem_complex_op ()(L_ace + i * nbands, 0 , i);
247+ }
248+
249+ // L_ace inv in place
250+ char non = ' N' ;
251+ lapack_trtri ()(lo, non, nbands, L_ace, nbands);
252+
253+ // Xi_ace = L_ace^-1 * h_psi_ace^dagger
254+ gemm_complex_op ()(' N' ,
255+ ' C' ,
256+ nbands,
257+ npwk,
258+ nbands,
259+ &intermediate_one,
260+ L_ace,
261+ nbands,
262+ h_psi_ace,
263+ nbasis,
264+ &intermediate_zero,
265+ Xi_ace,
266+ nbands);
267+
268+ // clear mem
269+ setmem_complex_op ()(h_psi_ace, 0 , nbands * nbasis);
270+ setmem_complex_op ()(psi_h_psi_ace, 0 , nbands * nbands);
271+ setmem_complex_op ()(L_ace, 0 , nbands * nbands);
236272 }
237-
238- // L_ace inv in place
239- char non = ' N' ;
240- lapack_trtri ()(lo, non, nbands, L_ace, nbands);
241-
242- // Xi_ace = L_ace^-1 * h_psi_ace^dagger
243- gemm_complex_op ()(' N' ,
244- ' C' ,
245- nbands,
246- npwk,
247- nbands,
248- &intermediate_one,
249- L_ace,
250- nbands,
251- h_psi_ace,
252- nbasis,
253- &intermediate_zero,
254- Xi_ace,
255- nbands);
256-
257- // clear mem
258- setmem_complex_op ()(h_psi_ace, 0 , nbands * nbasis);
259- setmem_complex_op ()(psi_h_psi_ace, 0 , nbands * nbands);
260- setmem_complex_op ()(L_ace, 0 , nbands * nbands);
261273 }
262-
263274 }
264275
265276 *ik_ = ik_save;
0 commit comments