@@ -70,17 +70,18 @@ void transfer_hr_gint_to_hR(const HContainer<T>& hr_gint, HContainer<T>& hR)
7070}
7171
7272// hRgint_tmp to hR
73- void transfer_hr_gint_to_hR_nspin4 (std::vector<HContainer<double >>& hRGint_tmp,
73+ void merge_hR_n4 (std::vector<HContainer<double >>& hRGint_tmp,
7474 HContainer<std::complex <double >>& hR,
7575 const GintInfo& gint_info)
7676{
77- ModuleBase::TITLE (" Gint" , " transfer_hr_gint_to_hR_nspin4 " );
78- ModuleBase::timer::tick (" Gint" , " transfer_hr_gint_to_hR_nspin4 " );
77+ ModuleBase::TITLE (" Gint" , " merge_hR_n4 " );
78+ ModuleBase::timer::tick (" Gint" , " merge_hR_n4 " );
7979#ifdef __MPI
8080 int mg = hR.get_paraV ()->get_global_row_size ()/2 ;
8181 int ng = hR.get_paraV ()->get_global_col_size ()/2 ;
8282 int nb = hR.get_paraV ()->get_block_size ()/2 ;
8383 int blacs_ctxt = hR.get_paraV ()->blacs_ctxt ;
84+
8485 const UnitCell* ucell = gint_info.get_ucell ();
8586 int *iat2iwt = new int [ucell->nat ];
8687 for (int iat = 0 ; iat < ucell->nat ; iat++) {
@@ -91,91 +92,49 @@ void transfer_hr_gint_to_hR_nspin4(std::vector<HContainer<double>>& hRGint_tmp,
9192 pv->set_atomic_trace (iat2iwt, ucell->nat , mg);
9293 auto ijr_info = hR.get_ijr_info ();
9394
94- hamilt::HContainer<double >* hR_tmp = new hamilt::HContainer<double >(pv, nullptr , &ijr_info);
95+ auto * hR_tmp = new hamilt::HContainer<std::complex <double >>(pv, nullptr , &ijr_info);
96+
97+ std::vector<int > first = {0 , 1 , 1 , 0 };
98+ std::vector<int > second= {3 , 2 , 2 , 3 };
99+ std::vector<int > row_set = {0 , 0 , 1 , 1 };
100+ std::vector<int > col_set = {0 , 1 , 0 , 1 };
101+ std::vector<int > clx_i = {1 , 0 , 0 , -1 };
102+ std::vector<int > clx_j = {0 , 1 , -1 , 0 };
95103 for (int is = 0 ; is < 4 ; is++){
96- hR_tmp->set_zero ();
97- // std::cout<<"is: "<<is<<std::endl;
98- hamilt::transferSerials2Parallels ( hRGint_tmp[is], hR_tmp);
99- for (int iap = 0 ; iap < hR.size_atom_pairs (); iap++)
104+ hamilt::HContainer<std::complex <double >>* hRGint_tmpCd = new hamilt::HContainer<std::complex <double >>(ucell->nat );
105+ ijr_info = hRGint_tmp[0 ].get_ijr_info ();
106+ hRGint_tmpCd->insert_ijrs (&ijr_info, *(ucell));
107+ hRGint_tmpCd->allocate (nullptr , true );
108+ hRGint_tmpCd->set_zero ();
109+ for (int iap = 0 ; iap < hRGint_tmpCd->size_atom_pairs (); iap++)
100110 {
101111 // std::cout<<"iap: "<<iap<<std::endl;
102- auto * ap = &hR. get_atom_pair (iap);
112+ auto * ap = &hRGint_tmpCd-> get_atom_pair (iap);
103113 const int iat1 = ap->get_atom_i ();
104114 const int iat2 = ap->get_atom_j ();
105- const hamilt::AtomPair<double >* ap_nspin = nullptr ;
106115 if (iat1 <= iat2)
107116 {
108117 hamilt::AtomPair<std::complex <double >>* upper_ap = ap;
109- hamilt::AtomPair<std::complex <double >>* lower_ap = hR.find_pair (iat2, iat1);
110- switch (is)
111- {
112- case 0 :
113- ap_nspin = hR_tmp->find_pair (iat1, iat2);
114- break ;
115- case 3 :
116- ap_nspin = hR_tmp->find_pair (iat1, iat2);
117- break ;
118- }
119- if (ap_nspin == nullptr ) break ;
118+ hamilt::AtomPair<std::complex <double >>* lower_ap = hRGint_tmpCd->find_pair (iat2, iat1);
119+ const hamilt::AtomPair<double >* ap_nspin1 = hRGint_tmp[first[is]].find_pair (iat1, iat2);
120+ const hamilt::AtomPair<double >* ap_nspin2 = hRGint_tmp[second[is]].find_pair (iat1, iat2);
120121 for (int ir = 0 ; ir < upper_ap->get_R_size (); ir++)
121122 {
122123 const auto R_index = upper_ap->get_R_index (ir);
123124 auto upper_mat = upper_ap->find_matrix (R_index);
124- auto mat_nspin = ap_nspin ->find_matrix (R_index);
125-
125+ auto mat_nspin1 = ap_nspin1 ->find_matrix (R_index);
126+ auto mat_nspin2 = ap_nspin2-> find_matrix (R_index);
126127 // The row size and the col size of upper_matrix is double that of matrix_nspin_0
127- for (int irow = 0 ; irow < mat_nspin ->get_row_size (); ++irow)
128+ for (int irow = 0 ; irow < mat_nspin1 ->get_row_size (); ++irow)
128129 {
129- for (int icol = 0 ; icol < mat_nspin ->get_col_size (); ++icol)
130+ for (int icol = 0 ; icol < mat_nspin1 ->get_col_size (); ++icol)
130131 {
131- switch (is)
132- {
133- case 0 :
134- upper_mat->get_value (2 *irow, 2 *icol) = mat_nspin->get_value (irow, icol);
135- upper_mat->get_value (2 *irow+1 , 2 *icol+1 ) = mat_nspin->get_value (irow, icol);
136- break ;
137- case 3 :
138- upper_mat->get_value (2 *irow, 2 *icol) += mat_nspin->get_value (irow, icol);
139- upper_mat->get_value (2 *irow+1 , 2 *icol+1 ) -= mat_nspin->get_value (irow, icol);
140- break ;
141- }
142- }
143- }
144-
145- if (PARAM.globalv .domag )
146- {
147- const hamilt::AtomPair<double >* ap_nspin = nullptr ;
148- switch (is)
149- {
150- case 1 :
151- ap_nspin = hR_tmp->find_pair (iat1, iat2);
152- break ;
153- case 2 :
154- ap_nspin = hR_tmp->find_pair (iat1, iat2);
155- break ;
156- }
157- const auto mat_nspin = ap_nspin->find_matrix (R_index);
158- for (int irow = 0 ; irow < mat_nspin->get_row_size (); ++irow)
159- {
160- for (int icol = 0 ; icol < mat_nspin->get_col_size (); ++icol)
161- {
162- switch (is)
163- {
164- case 1 :
165- upper_mat->get_value (2 *irow, 2 *icol+1 ) = mat_nspin->get_value (irow, icol);
166- upper_mat->get_value (2 *irow+1 , 2 *icol) = mat_nspin->get_value (irow, icol);
167- break ;
168- case 2 :
169- upper_mat->get_value (2 *irow, 2 *icol+1 ) += std::complex <double >(0.0 , 1.0 ) * mat_nspin->get_value (irow, icol);
170- upper_mat->get_value (2 *irow+1 , 2 *icol) -= std::complex <double >(0.0 , 1.0 ) * mat_nspin->get_value (irow, icol);
171- break ;
172- }
173- }
132+ upper_mat->get_value (irow, icol) = mat_nspin1->get_value (irow, icol)
133+ + std::complex <double >(clx_i[is], clx_j[is]) * mat_nspin2->get_value (irow, icol);
174134 }
175135 }
176-
177- // fill the lower triangle matrix
178- if (is == 3 ){
136+ // fill the lower triangle matrix
137+ if (PARAM.globalv .domag ){
179138 if (iat1 < iat2)
180139 {
181140 auto lower_mat = lower_ap->find_matrix (-R_index);
@@ -191,15 +150,41 @@ void transfer_hr_gint_to_hR_nspin4(std::vector<HContainer<double>>& hRGint_tmp,
191150 }
192151 }
193152 }
194-
153+
154+ hR_tmp->set_zero ();
155+ hamilt::transferSerials2Parallels ( *hRGint_tmpCd, hR_tmp);
156+ for (int iap = 0 ; iap < hR.size_atom_pairs (); iap++)
157+ {
158+ auto * ap = &hR.get_atom_pair (iap);
159+ const int iat1 = ap->get_atom_i ();
160+ const int iat2 = ap->get_atom_j ();
161+ auto * ap_nspin = hR_tmp ->find_pair (iat1, iat2);
162+ for (int ir = 0 ; ir < ap->get_R_size (); ir++)
163+ {
164+ const auto R_index = ap->get_R_index (ir);
165+ auto upper_mat = ap->find_matrix (R_index);
166+ auto mat_nspin = ap_nspin->find_matrix (R_index);
167+
168+ // The row size and the col size of upper_matrix is double that of matrix_nspin_0
169+ for (int irow = 0 ; irow < mat_nspin->get_row_size (); ++irow)
170+ {
171+ for (int icol = 0 ; icol < mat_nspin->get_col_size (); ++icol)
172+ {
173+ upper_mat->get_value (2 *irow+row_set[is], 2 *icol+col_set[is]) =
174+ mat_nspin->get_value (irow, icol);
175+ }
176+ }
177+ }
178+ }
179+ delete hRGint_tmpCd;
195180 }
196181 delete[] iat2iwt;
197- delete pv;
198- delete hR_tmp;
199182#else
200183
201184#endif
202- ModuleBase::timer::tick (" Gint" , " transfer_hr_gint_to_hR_nspin4" );
185+
186+
187+ ModuleBase::timer::tick (" Gint" , " merge_hR_n4" );
203188 return ;
204189}
205190
@@ -231,6 +216,9 @@ void transfer_dm_2d_to_gint(
231216 } else // NSPIN=4 case
232217 {
233218#ifdef __MPI
219+ // is=0:↑↑, 1:↑↓, 2:↓↑, 3:↓↓
220+ const int row_set[4 ] = {0 , 0 , 1 , 1 };
221+ const int col_set[4 ] = {0 , 1 , 0 , 1 };
234222 int mg = dm[0 ]->get_paraV ()->get_global_row_size ()/2 ;
235223 int ng = dm[0 ]->get_paraV ()->get_global_col_size ()/2 ;
236224 int nb = dm[0 ]->get_paraV ()->get_block_size ()/2 ;
@@ -246,43 +234,20 @@ void transfer_dm_2d_to_gint(
246234 auto ijr_info = dm[0 ]->get_ijr_info ();
247235 HContainer<T>* DM2D_tmp = new hamilt::HContainer<T>(pv, nullptr , &ijr_info);
248236 // ModuleBase::Memory::record("Gint::DM2D_tmp", this->DM2D_tmp->get_memory_size());
249- for (int is = 0 ; is < 4 ; is++){
237+ for (int is = 0 ; is < 4 ; is++){
250238 for (int iap = 0 ; iap < dm[0 ]->size_atom_pairs (); ++iap) {
251239 auto & ap = dm[0 ]->get_atom_pair (iap);
252240 int iat1 = ap.get_atom_i ();
253241 int iat2 = ap.get_atom_j ();
254242 for (int ir = 0 ; ir < ap.get_R_size (); ++ir) {
255243 const ModuleBase::Vector3<int > r_index = ap.get_R_index (ir);
256- T* tmp_pointer = DM2D_tmp -> find_matrix (iat1, iat2, r_index)->get_pointer ();
257- T* data_full = ap.get_pointer (ir);
258- for (int irow = 0 ; irow < ap.get_row_size (); irow += 2 ) {
259- switch (is) {// todo: It can be written more compactly
260- case 0 :
261- for (int icol = 0 ; icol < ap.get_col_size (); icol += 2 ) {
262- *(tmp_pointer)++ = data_full[icol];
263- }
264- data_full += ap.get_col_size () * 2 ;
265- break ;
266- case 1 :
267- for (int icol = 0 ; icol < ap.get_col_size (); icol += 2 ) {
268- *(tmp_pointer)++ = data_full[icol + 1 ];
269- }
270- data_full += ap.get_col_size () * 2 ;
271- break ;
272- case 2 :
273- data_full += ap.get_col_size ();
274- for (int icol = 0 ; icol < ap.get_col_size (); icol += 2 ) {
275- *(tmp_pointer)++ = data_full[icol];
276- }
277- data_full += ap.get_col_size ();
278- break ;
279- case 3 :
280- data_full += ap.get_col_size ();
281- for (int icol = 0 ; icol < ap.get_col_size (); icol += 2 ) {
282- *(tmp_pointer)++ = data_full[icol + 1 ];
283- }
284- data_full += ap.get_col_size ();
285- break ;
244+ T* matrix_out = DM2D_tmp -> find_matrix (iat1, iat2, r_index)->get_pointer ();
245+ T* matrix_in = ap.get_pointer (ir);
246+ for (int irow = 0 ; irow < ap.get_row_size ()/2 ; irow ++) {
247+ for (int icol = 0 ; icol < ap.get_col_size ()/2 ; icol ++) {
248+ int index_i = irow* ap.get_col_size ()/2 + icol;
249+ int index_j = (irow*2 +row_set[is]) * ap.get_col_size () + icol*2 +col_set[is];
250+ matrix_out[index_i] = matrix_in[index_j];
286251 }
287252 }
288253 }
0 commit comments