1414#include " module_parameter/parameter.h"
1515
1616void DeePKS_domain::prepare_phialpha_r (const int nlocal,
17- const int lmaxd,
18- const int inlmax,
19- const int nat,
20- const std::vector<hamilt::HContainer< double >*> phialpha ,
21- const UnitCell& ucell ,
22- const LCAO_Orbitals& orb ,
23- const Parallel_Orbitals& pv ,
24- const Grid_Driver& GridD ,
25- torch::Tensor& phialpha_r_out ,
26- torch::Tensor& R_query )
17+ const int lmaxd,
18+ const int inlmax,
19+ const int nat,
20+ const int R_size ,
21+ const std::vector<hamilt::HContainer< double >*> phialpha ,
22+ const UnitCell& ucell ,
23+ const LCAO_Orbitals& orb ,
24+ const Parallel_Orbitals& pv ,
25+ const Grid_Driver& GridD ,
26+ torch::Tensor& phialpha_r_out )
2727{
2828 ModuleBase::TITLE (" DeePKS_domain" , " prepare_phialpha_r" );
2929 ModuleBase::timer::tick (" DeePKS_domain" , " prepare_phialpha_r" );
3030 constexpr torch::Dtype dtype = torch::kFloat64 ;
3131 int nlmax = inlmax / nat;
3232 int mmax = 2 * lmaxd + 1 ;
33- auto size_R = static_cast <long >(phialpha[0 ]->size_R_loop ());
34- phialpha_r_out = torch::zeros ({size_R, nat, nlmax, nlocal, mmax}, dtype);
35- R_query = torch::zeros ({size_R, 3 }, torch::kInt32 );
36- auto accessor = phialpha_r_out.accessor <double , 5 >();
37- auto R_accessor = R_query.accessor <int , 2 >();
3833
39- for (int iR = 0 ; iR < size_R; ++iR)
40- {
41- phialpha[0 ]->loop_R (iR, R_accessor[iR][0 ], R_accessor[iR][1 ], R_accessor[iR][2 ]);
42- }
34+ phialpha_r_out = torch::zeros ({R_size, R_size, R_size, nat, nlmax, nlocal, mmax}, dtype);
35+ auto accessor = phialpha_r_out.accessor <double , 7 >();
4336
4437 DeePKS_domain::iterate_ad1 (
4538 ucell,
@@ -81,18 +74,22 @@ void DeePKS_domain::prepare_phialpha_r(const int nlocal,
8174 const int nm = 2 * L0 + 1 ;
8275 for (int m1 = 0 ; m1 < nm; ++m1) // nm = 1 for s, 3 for p, 5 for d
8376 {
84- accessor[iR][iat][nl][iw1_all][m1] += overlap->get_value (iw1, ib + m1);
77+ int iRx = DeePKS_domain::mapping_R (dR.x );
78+ int iRy = DeePKS_domain::mapping_R (dR.y );
79+ int iRz = DeePKS_domain::mapping_R (dR.z );
80+ accessor[iRx][iRy][iRz][iat][nl][iw1_all][m1]
81+ += overlap->get_value (iw1, ib + m1);
8582 }
8683 ib += nm;
8784 nl++;
8885 }
8986 }
90- } // end iw
87+ } // end iw
9188 }
9289 );
9390
9491#ifdef __MPI
95- int size = size_R * nat * nlmax * nlocal * mmax;
92+ int size = R_size * R_size * R_size * nat * nlmax * nlocal * mmax;
9693 double * data_ptr = phialpha_r_out.data_ptr <double >();
9794 Parallel_Reduce::reduce_all (data_ptr, size);
9895
@@ -101,4 +98,183 @@ void DeePKS_domain::prepare_phialpha_r(const int nlocal,
10198 ModuleBase::timer::tick (" DeePKS_domain" , " prepare_phialpha_r" );
10299 return ;
103100}
101+
102+ void DeePKS_domain::cal_vdr_precalc (const int nlocal,
103+ const int lmaxd,
104+ const int inlmax,
105+ const int nat,
106+ const int nks,
107+ const int R_size,
108+ const std::vector<int >& inl2l,
109+ const std::vector<ModuleBase::Vector3<double >>& kvec_d,
110+ const std::vector<hamilt::HContainer<double >*> phialpha,
111+ const std::vector<torch::Tensor> gevdm,
112+ const ModuleBase::IntArray* inl_index,
113+ const UnitCell& ucell,
114+ const LCAO_Orbitals& orb,
115+ const Parallel_Orbitals& pv,
116+ const Grid_Driver& GridD,
117+ torch::Tensor& vdr_precalc)
118+ {
119+ ModuleBase::TITLE (" DeePKS_domain" , " calc_vdr_precalc" );
120+ ModuleBase::timer::tick (" DeePKS_domain" , " calc_vdr_precalc" );
121+
122+ torch::Tensor vdr_pdm
123+ = torch::zeros ({R_size, R_size, R_size, nlocal, nlocal, inlmax, (2 * lmaxd + 1 ), (2 * lmaxd + 1 )},
124+ torch::TensorOptions ().dtype (torch::kFloat64 ));
125+ auto accessor = vdr_pdm.accessor <double , 8 >();
126+
127+ DeePKS_domain::iterate_ad2 (
128+ ucell,
129+ GridD,
130+ orb,
131+ false , // no trace_alpha
132+ [&](const int iat,
133+ const ModuleBase::Vector3<double >& tau0,
134+ const int ibt1,
135+ const ModuleBase::Vector3<double >& tau1,
136+ const int start1,
137+ const int nw1_tot,
138+ ModuleBase::Vector3<int > dR1,
139+ const int ibt2,
140+ const ModuleBase::Vector3<double >& tau2,
141+ const int start2,
142+ const int nw2_tot,
143+ ModuleBase::Vector3<int > dR2)
144+ {
145+ const int T0 = ucell.iat2it [iat];
146+ const int I0 = ucell.iat2ia [iat];
147+ if (phialpha[0 ]->find_matrix (iat, ibt1, dR1.x , dR1.y , dR1.z ) == nullptr
148+ || phialpha[0 ]->find_matrix (iat, ibt2, dR2.x , dR2.y , dR2.z ) == nullptr )
149+ {
150+ return ; // to next loop
151+ }
152+
153+ hamilt::BaseMatrix<double >* overlap_1 = phialpha[0 ]->find_matrix (iat, ibt1, dR1);
154+ hamilt::BaseMatrix<double >* overlap_2 = phialpha[0 ]->find_matrix (iat, ibt2, dR2);
155+ assert (overlap_1->get_col_size () == overlap_2->get_col_size ());
156+ ModuleBase::Vector3<int > dR = dR1 - dR2;
157+ int iRx = DeePKS_domain::mapping_R (dR.x );
158+ int iRy = DeePKS_domain::mapping_R (dR.y );
159+ int iRz = DeePKS_domain::mapping_R (dR.z );
160+
161+ for (int iw1 = 0 ; iw1 < nw1_tot; ++iw1)
162+ {
163+ const int iw1_all = start1 + iw1; // this is \mu
164+ const int iw1_local = pv.global2local_row (iw1_all);
165+ if (iw1_local < 0 )
166+ {
167+ continue ;
168+ }
169+ for (int iw2 = 0 ; iw2 < nw2_tot; ++iw2)
170+ {
171+ const int iw2_all = start2 + iw2; // this is \nu
172+ const int iw2_local = pv.global2local_col (iw2_all);
173+ if (iw2_local < 0 )
174+ {
175+ continue ;
176+ }
177+
178+ int ib = 0 ;
179+ for (int L0 = 0 ; L0 <= orb.Alpha [0 ].getLmax (); ++L0)
180+ {
181+ for (int N0 = 0 ; N0 < orb.Alpha [0 ].getNchi (L0); ++N0)
182+ {
183+ const int inl = inl_index[T0](I0, L0, N0);
184+ const int nm = 2 * L0 + 1 ;
185+
186+ for (int m1 = 0 ; m1 < nm; ++m1) // nm = 1 for s, 3 for p, 5 for d
187+ {
188+ for (int m2 = 0 ; m2 < nm; ++m2) // nm = 1 for s, 3 for p, 5 for d
189+ {
190+ double tmp = overlap_1->get_value (iw1, ib + m1)
191+ * overlap_2->get_value (iw2, ib + m2);
192+ accessor[iRx][iRy][iRz][iw1_all][iw2_all][inl][m1][m2]
193+ += tmp;
194+ }
195+ }
196+ ib += nm;
197+ }
198+ }
199+ } // iw2
200+ } // iw1
201+ }
202+ );
203+
204+ #ifdef __MPI
205+ const int size = R_size * R_size * R_size * nlocal * nlocal * inlmax * (2 * lmaxd + 1 ) * (2 * lmaxd + 1 );
206+ double * data_ptr = vdr_pdm.data_ptr <double >();
207+ Parallel_Reduce::reduce_all (data_ptr, size);
208+ #endif
209+
210+ // transfer v_delta_pdm to v_delta_pdm_vector
211+ int nlmax = inlmax / nat;
212+ std::vector<torch::Tensor> vdr_pdm_vector;
213+ for (int nl = 0 ; nl < nlmax; ++nl)
214+ {
215+ int nm = 2 * inl2l[nl] + 1 ;
216+ torch::Tensor vdr_pdm_sliced = vdr_pdm.slice (5 , nl, inlmax, nlmax).slice (6 , 0 , nm, 1 ).slice (7 , 0 , nm, 1 );
217+ vdr_pdm_vector.push_back (vdr_pdm_sliced);
218+ }
219+
220+ assert (vdr_pdm_vector.size () == nlmax);
221+
222+ // einsum for each nl:
223+ std::vector<torch::Tensor> vdr_vector;
224+ for (int nl = 0 ; nl < nlmax; ++nl)
225+ {
226+ vdr_vector.push_back (at::einsum (" pqrxyamn, avmn->pqrxyav" , {vdr_pdm_vector[nl], gevdm[nl]}));
227+ }
228+
229+ vdr_precalc = torch::cat (vdr_vector, -1 );
230+
231+ ModuleBase::timer::tick (" DeePKS_domain" , " calc_vdr_precalc" );
232+ return ;
233+ }
234+
235+ int DeePKS_domain::mapping_R (int R)
236+ {
237+ // R_index mapping: index(R) = 2R-1 if R > 0, index(R) = -2R if R <= 0
238+ // after mapping, the new index [0,1,2,3,4,...] -> old index [0,1,-1,2,-2,...]
239+ // This manipulation makes sure that the new index is natural number
240+ // which makes it available to be used as index in torch::Tensor
241+ int R_index = 0 ;
242+ if (R > 0 )
243+ {
244+ R_index = 2 * R - 1 ;
245+ }
246+ else
247+ {
248+ R_index = -2 * R;
249+ }
250+ return R_index;
251+ }
252+
253+ template <typename T>
254+ int DeePKS_domain::get_R_size (const hamilt::HContainer<T>& hcontainer)
255+ {
256+ // get R_size from hcontainer
257+ int R_size = 0 ;
258+ if (hcontainer.size_R_loop () > 0 )
259+ {
260+ for (int iR = 0 ; iR < hcontainer.size_R_loop (); ++iR)
261+ {
262+ ModuleBase::Vector3<int > R_vec;
263+ hcontainer.loop_R (iR, R_vec.x , R_vec.y , R_vec.z );
264+ int R_min = std::min ({R_vec.x , R_vec.y , R_vec.z });
265+ int R_max = std::max ({R_vec.x , R_vec.y , R_vec.z });
266+ int tmp_R_size = std::max (DeePKS_domain::mapping_R (R_min), DeePKS_domain::mapping_R (R_max)) + 1 ;
267+ if (tmp_R_size > R_size)
268+ {
269+ R_size = tmp_R_size;
270+ }
271+ }
272+ }
273+ assert (R_size > 0 );
274+ return R_size;
275+ }
276+
277+ template int DeePKS_domain::get_R_size<double >(const hamilt::HContainer<double >& hcontainer);
278+ template int DeePKS_domain::get_R_size<std::complex <double >>(
279+ const hamilt::HContainer<std::complex <double >>& hcontainer);
104280#endif
0 commit comments