1515// / gdmepsl = d/d\epsilon_{ab} *
1616// / sum_{mu,nu} rho_{mu,nu} <chi_mu|alpha_m><alpha_m'|chi_nu>
1717template <typename TK>
18- void DeePKS_domain::cal_gdmepsl (const int lmaxd,
19- const int inlmax,
20- const int nks,
18+ void DeePKS_domain::cal_gdmepsl (const int nks,
19+ const DeePKS_Param& deepks_param,
2120 const std::vector<ModuleBase::Vector3<double >>& kvec_d,
2221 std::vector<hamilt::HContainer<double >*> phialpha,
23- const ModuleBase::IntArray* inl_index,
2422 const hamilt::HContainer<double >* dmr,
2523 const UnitCell& ucell,
2624 const LCAO_Orbitals& orb,
@@ -33,10 +31,10 @@ void DeePKS_domain::cal_gdmepsl(const int lmaxd,
3331 // get DS_alpha_mu and S_nu_beta
3432
3533 int nrow = pv.nrow ;
36- const int nm = 2 * lmaxd + 1 ;
34+ const int nm = 2 * deepks_param. lmaxd + 1 ;
3735 // gdmepsl: dD/d\epsilon_{\alpha\beta}
3836 // size: [6][tot_Inl][2l+1][2l+1]
39- gdmepsl = torch::zeros ({6 , inlmax, nm, nm}, torch::dtype (torch::kFloat64 ));
37+ gdmepsl = torch::zeros ({6 , deepks_param. inlmax , nm, nm}, torch::dtype (torch::kFloat64 ));
4038 auto accessor = gdmepsl.accessor <double , 4 >();
4139
4240 DeePKS_domain::iterate_ad2 (
@@ -111,7 +109,7 @@ void DeePKS_domain::cal_gdmepsl(const int lmaxd,
111109 {
112110 for (int N0 = 0 ; N0 < orb.Alpha [0 ].getNchi (L0); ++N0)
113111 {
114- const int inl = inl_index[ucell.iat2it [iat]](ucell.iat2ia [iat], L0, N0);
112+ const int inl = deepks_param. inl_index [ucell.iat2it [iat]](ucell.iat2ia [iat], L0, N0);
115113 const int nm = 2 * L0 + 1 ;
116114 for (int m1 = 0 ; m1 < nm; ++m1)
117115 {
@@ -147,7 +145,7 @@ void DeePKS_domain::cal_gdmepsl(const int lmaxd,
147145 );
148146
149147#ifdef __MPI
150- Parallel_Reduce::reduce_all (gdmepsl.data_ptr <double >(), 6 * inlmax * nm * nm);
148+ Parallel_Reduce::reduce_all (gdmepsl.data_ptr <double >(), 6 * deepks_param. inlmax * nm * nm);
151149#endif
152150 ModuleBase::timer::tick (" DeePKS_domain" , " cal_gdmepsl" );
153151 return ;
@@ -156,9 +154,7 @@ void DeePKS_domain::cal_gdmepsl(const int lmaxd,
156154// calculates stress of descriptors from gradient of projected density matrices
157155// gv_epsl:d(d)/d\epsilon_{\alpha\beta}, [natom][6][des_per_atom]
158156void DeePKS_domain::cal_gvepsl (const int nat,
159- const int inlmax,
160- const int des_per_atom,
161- const std::vector<int >& inl2l,
157+ const DeePKS_Param& deepks_param,
162158 const std::vector<torch::Tensor>& gevdm,
163159 const torch::Tensor& gdmepsl,
164160 torch::Tensor& gvepsl,
@@ -172,11 +168,11 @@ void DeePKS_domain::cal_gvepsl(const int nat,
172168 if (rank == 0 )
173169 {
174170 // make gdmepsl as tensor
175- int nlmax = inlmax / nat;
171+ int nlmax = deepks_param. inlmax / nat;
176172 for (int nl = 0 ; nl < nlmax; ++nl)
177173 {
178- int nm = 2 * inl2l[nl] + 1 ;
179- torch::Tensor gdmepsl_sliced = gdmepsl.slice (1 , nl, inlmax, nlmax).slice (2 , 0 , nm, 1 ).slice (3 , 0 , nm, 1 );
174+ int nm = 2 * deepks_param. inl2l [nl] + 1 ;
175+ torch::Tensor gdmepsl_sliced = gdmepsl.slice (1 , nl, deepks_param. inlmax , nlmax).slice (2 , 0 , nm, 1 ).slice (3 , 0 , nm, 1 );
180176 gdmepsl_vector.push_back (gdmepsl_sliced);
181177 }
182178 assert (gdmepsl_vector.size () == nlmax);
@@ -197,32 +193,28 @@ void DeePKS_domain::cal_gvepsl(const int nat,
197193 gvepsl = torch::cat (gvepsl_vector, -1 );
198194 assert (gvepsl.size (0 ) == 6 );
199195 assert (gvepsl.size (1 ) == nat);
200- assert (gvepsl.size (2 ) == des_per_atom);
196+ assert (gvepsl.size (2 ) == deepks_param. des_per_atom );
201197 }
202198
203199 ModuleBase::timer::tick (" DeePKS_domain" , " cal_gvepsl" );
204200 return ;
205201}
206202
207- template void DeePKS_domain::cal_gdmepsl<double >(const int lmaxd,
208- const int inlmax,
209- const int nks,
203+ template void DeePKS_domain::cal_gdmepsl<double >(const int nks,
204+ const DeePKS_Param& deepks_param,
210205 const std::vector<ModuleBase::Vector3<double >>& kvec_d,
211206 std::vector<hamilt::HContainer<double >*> phialpha,
212- const ModuleBase::IntArray* inl_index,
213207 const hamilt::HContainer<double >* dmr,
214208 const UnitCell& ucell,
215209 const LCAO_Orbitals& orb,
216210 const Parallel_Orbitals& pv,
217211 const Grid_Driver& GridD,
218212 torch::Tensor& gdmepsl);
219213
220- template void DeePKS_domain::cal_gdmepsl<std::complex <double >>(const int lmaxd,
221- const int inlmax,
222- const int nks,
214+ template void DeePKS_domain::cal_gdmepsl<std::complex <double >>(const int nks,
215+ const DeePKS_Param& deepks_param,
223216 const std::vector<ModuleBase::Vector3<double >>& kvec_d,
224217 std::vector<hamilt::HContainer<double >*> phialpha,
225- const ModuleBase::IntArray* inl_index,
226218 const hamilt::HContainer<double >* dmr,
227219 const UnitCell& ucell,
228220 const LCAO_Orbitals& orb,
0 commit comments