@@ -83,56 +83,54 @@ namespace LR
8383
8484 // convert parallel info to LibRI interfaces
8585 std::vector<std::tuple<std::set<TA>, std::set<TA>>> judge = RI_2D_Comm::get_2D_judge (this ->pmat );
86- for (int ib = 0 ;ib < nbands;++ib)
87- {
88- const int xstart_b = ib * nk * pX.get_local_size ();
89- // suppose Cs,Vs, have already been calculated in the ion-step of ground state
90- // and DM_trans has been calculated in hPsi() outside.
9186
92- // 1. set_Ds (once)
93- // convert to vector<T*> for the interface of RI_2D_Comm::split_m2D_ktoR (interface will be unified to ct::Tensor)
94- std::vector<std::vector<T>> DMk_trans_vector = this ->DM_trans ->get_DMK_vector ();
95- // assert(DMk_trans_vector.size() == nk);
96- std::vector<const std::vector<T>*> DMk_trans_pointer (nk);
97- for (int ik = 0 ;ik < nk;++ik) {DMk_trans_pointer[ik] = &DMk_trans_vector[ik];}
98- // if multi-k, DM_trans(TR=double) -> Ds_trans(TR=T=complex<double>)
99- std::vector<std::map<TA, std::map<TAC, RI::Tensor<T>>>> Ds_trans =
100- aims_nbasis.empty () ?
101- RI_2D_Comm::split_m2D_ktoR<T>(this ->kv , DMk_trans_pointer, this ->pmat , 1 )
102- : RI_Benchmark::split_Ds (DMk_trans_vector, aims_nbasis, ucell); // 0.5 will be multiplied
103- // LR_Util::print_CV(Ds_trans[0], "Ds_trans in OperatorLREXX", 1e-10);
104- // 2. cal_Hs
105- auto lri = this ->exx_lri .lock ();
87+ // suppose Cs,Vs, have already been calculated in the ion-step of ground state
88+ // and DM_trans has been calculated in hPsi() outside.
89+
90+ // 1. set_Ds (once)
91+ // convert to vector<T*> for the interface of RI_2D_Comm::split_m2D_ktoR (interface will be unified to ct::Tensor)
92+ std::vector<std::vector<T>> DMk_trans_vector = this ->DM_trans ->get_DMK_vector ();
93+ // assert(DMk_trans_vector.size() == nk);
94+ std::vector<const std::vector<T>*> DMk_trans_pointer (nk);
95+ for (int ik = 0 ;ik < nk;++ik) { DMk_trans_pointer[ik] = &DMk_trans_vector[ik]; }
96+ // if multi-k, DM_trans(TR=double) -> Ds_trans(TR=T=complex<double>)
97+ std::vector<std::map<TA, std::map<TAC, RI::Tensor<T>>>> Ds_trans =
98+ aims_nbasis.empty () ?
99+ RI_2D_Comm::split_m2D_ktoR<T>(this ->kv , DMk_trans_pointer, this ->pmat , 1 )
100+ : RI_Benchmark::split_Ds (DMk_trans_vector, aims_nbasis, ucell); // 0.5 will be multiplied
101+ // LR_Util::print_CV(Ds_trans[0], "Ds_trans in OperatorLREXX", 1e-10);
102+ // 2. cal_Hs
103+ auto lri = this ->exx_lri .lock ();
106104
107- // LR_Util::print_CV(Ds_trans[is], "Ds_trans in OperatorLREXX", 1e-10);
108- lri->exx_lri .set_Ds (std::move (Ds_trans[0 ]), lri->info .dm_threshold );
109- lri->exx_lri .cal_Hs ();
110- lri->Hexxs [0 ] = RI::Communicate_Tensors_Map_Judge::comm_map2_first (
111- lri->mpi_comm , std::move (lri->exx_lri .Hs ), std::get<0 >(judge[0 ]), std::get<1 >(judge[0 ]));
112- lri->post_process_Hexx (lri->Hexxs [0 ]);
105+ // LR_Util::print_CV(Ds_trans[is], "Ds_trans in OperatorLREXX", 1e-10);
106+ lri->exx_lri .set_Ds (std::move (Ds_trans[0 ]), lri->info .dm_threshold );
107+ lri->exx_lri .cal_Hs ();
108+ lri->Hexxs [0 ] = RI::Communicate_Tensors_Map_Judge::comm_map2_first (
109+ lri->mpi_comm , std::move (lri->exx_lri .Hs ), std::get<0 >(judge[0 ]), std::get<1 >(judge[0 ]));
110+ lri->post_process_Hexx (lri->Hexxs [0 ]);
113111
114- // 3. set [AX]_iak = DM_onbase * Hexxs for each occ-virt pair and each k-point
115- // caution: parrallel
112+ // 3. set [AX]_iak = DM_onbase * Hexxs for each occ-virt pair and each k-point
113+ // caution: parrallel
116114
117- for (int io = 0 ;io < this ->nocc ;++io)
115+ for (int io = 0 ;io < this ->nocc ;++io)
116+ {
117+ for (int iv = 0 ;iv < this ->nvirt ;++iv)
118118 {
119- for (int iv = 0 ;iv < this -> nvirt ;++iv )
119+ for (int ik = 0 ;ik < nk ;++ik )
120120 {
121- for (int ik = 0 ;ik < nk;++ik)
121+ const int xstart_bk = ik * pX.get_local_size ();
122+ this ->cal_DM_onebase (io, iv, ik); // set Ds_onebase for all e-h pairs (not only on this processor)
123+ // LR_Util::print_CV(Ds_onebase[is], "Ds_onebase of occ " + std::to_string(io) + ", virtual " + std::to_string(iv) + " in OperatorLREXX", 1e-10);
124+ const T& ene = 2 * alpha * // minus for exchange(but here plus is right, why?), 2 for Hartree to Ry
125+ lri->exx_lri .post_2D .cal_energy (this ->Ds_onebase , lri->Hexxs [0 ]);
126+ if (this ->pX .in_this_processor (iv, io))
122127 {
123- const int xstart_bk = xstart_b + ik * pX.get_local_size ();
124- this ->cal_DM_onebase (io, iv, ik); // set Ds_onebase for all e-h pairs (not only on this processor)
125- // LR_Util::print_CV(Ds_onebase[is], "Ds_onebase of occ " + std::to_string(io) + ", virtual " + std::to_string(iv) + " in OperatorLREXX", 1e-10);
126- const T& ene = 2 * alpha * // minus for exchange(but here plus is right, why?), 2 for Hartree to Ry
127- lri->exx_lri .post_2D .cal_energy (this ->Ds_onebase , lri->Hexxs [0 ]);
128- if (this ->pX .in_this_processor (iv, io))
129- {
130- hpsi[xstart_bk + ik * pX.get_local_size () + this ->pX .global2local_col (io) * this ->pX .get_row_size () + this ->pX .global2local_row (iv)] += ene;
131- }
128+ hpsi[xstart_bk + ik * pX.get_local_size () + this ->pX .global2local_col (io) * this ->pX .get_row_size () + this ->pX .global2local_row (iv)] += ene;
132129 }
133130 }
134131 }
135132 }
133+
136134 }
137135 template class OperatorLREXX <double >;
138136 template class OperatorLREXX <std::complex <double >>;
0 commit comments