2222// for deepks_v_delta = 1
2323template <typename TK>
2424void DeePKS_domain::cal_v_delta_precalc (const int nlocal,
25- const int lmaxd,
26- const int inlmax,
2725 const int nat,
2826 const int nks,
29- const std::vector< int >& inl2l ,
27+ const DeePKS_Param& deepks_param ,
3028 const std::vector<ModuleBase::Vector3<double >>& kvec_d,
3129 const std::vector<hamilt::HContainer<double >*> phialpha,
3230 const std::vector<torch::Tensor> gevdm,
33- const ModuleBase::IntArray* inl_index,
3431 const UnitCell& ucell,
3532 const LCAO_Orbitals& orb,
3633 const Parallel_Orbitals& pv,
@@ -47,7 +44,7 @@ void DeePKS_domain::cal_v_delta_precalc(const int nlocal,
4744 typename std::conditional<std::is_same<TK, std::complex <double >>::value, c10::complex <double >, TK>::type;
4845
4946 torch::Tensor v_delta_pdm
50- = torch::zeros ({nks, nlocal, nlocal, inlmax, (2 * lmaxd + 1 ), (2 * lmaxd + 1 )}, torch::dtype (dtype));
47+ = torch::zeros ({nks, nlocal, nlocal, deepks_param. inlmax , (2 * deepks_param. lmaxd + 1 ), (2 * deepks_param. lmaxd + 1 )}, torch::dtype (dtype));
5148 auto accessor = v_delta_pdm.accessor <TK_tensor, 6 >();
5249
5350 DeePKS_domain::iterate_ad2 (
@@ -112,7 +109,7 @@ void DeePKS_domain::cal_v_delta_precalc(const int nlocal,
112109 {
113110 for (int N0 = 0 ; N0 < orb.Alpha [0 ].getNchi (L0); ++N0)
114111 {
115- const int inl = inl_index[T0](I0, L0, N0);
112+ const int inl = deepks_param. inl_index [T0](I0, L0, N0);
116113 const int nm = 2 * L0 + 1 ;
117114 for (int m1 = 0 ; m1 < nm; ++m1) // nm = 1 for s, 3 for p, 5 for d
118115 {
@@ -133,20 +130,20 @@ void DeePKS_domain::cal_v_delta_precalc(const int nlocal,
133130 }
134131 );
135132#ifdef __MPI
136- const int size = nks * nlocal * nlocal * inlmax * (2 * lmaxd + 1 ) * (2 * lmaxd + 1 );
133+ const int size = nks * nlocal * nlocal * deepks_param. inlmax * (2 * deepks_param. lmaxd + 1 ) * (2 * deepks_param. lmaxd + 1 );
137134 TK_tensor* data_tensor_ptr = v_delta_pdm.data_ptr <TK_tensor>();
138135 TK* data_ptr = reinterpret_cast <TK*>(data_tensor_ptr);
139136 Parallel_Reduce::reduce_all (data_ptr, size);
140137#endif
141138
142139 // transfer v_delta_pdm to v_delta_pdm_vector
143- int nlmax = inlmax / nat;
140+ int nlmax = deepks_param. inlmax / nat;
144141 std::vector<torch::Tensor> v_delta_pdm_vector;
145142 for (int nl = 0 ; nl < nlmax; ++nl)
146143 {
147- int nm = 2 * inl2l[nl] + 1 ;
144+ int nm = 2 * deepks_param. inl2l [nl] + 1 ;
148145 torch::Tensor v_delta_pdm_sliced
149- = v_delta_pdm.slice (3 , nl, inlmax, nlmax).slice (4 , 0 , nm, 1 ).slice (5 , 0 , nm, 1 );
146+ = v_delta_pdm.slice (3 , nl, deepks_param. inlmax , nlmax).slice (4 , 0 , nm, 1 ).slice (5 , 0 , nm, 1 );
150147 v_delta_pdm_vector.push_back (v_delta_pdm_sliced);
151148 }
152149
@@ -173,10 +170,9 @@ void DeePKS_domain::cal_v_delta_precalc(const int nlocal,
173170// prepare_phialpha and prepare_gevdm for deepks_v_delta = 2
174171template <typename TK>
175172void DeePKS_domain::prepare_phialpha (const int nlocal,
176- const int lmaxd,
177- const int inlmax,
178173 const int nat,
179174 const int nks,
175+ const DeePKS_Param& deepks_param,
180176 const std::vector<ModuleBase::Vector3<double >>& kvec_d,
181177 const std::vector<hamilt::HContainer<double >*> phialpha,
182178 const UnitCell& ucell,
@@ -190,8 +186,8 @@ void DeePKS_domain::prepare_phialpha(const int nlocal,
190186 constexpr torch::Dtype dtype = std::is_same<TK, double >::value ? torch::kFloat64 : torch::kComplexDouble ;
191187 using TK_tensor =
192188 typename std::conditional<std::is_same<TK, std::complex <double >>::value, c10::complex <double >, TK>::type;
193- int nlmax = inlmax / nat;
194- int mmax = 2 * lmaxd + 1 ;
189+ int nlmax = deepks_param. inlmax / nat;
190+ int mmax = 2 * deepks_param. lmaxd + 1 ;
195191 phialpha_out = torch::zeros ({nat, nlmax, nks, nlocal, mmax}, dtype);
196192 auto accessor = phialpha_out.accessor <TK_tensor, 5 >();
197193
@@ -268,16 +264,15 @@ void DeePKS_domain::prepare_phialpha(const int nlocal,
268264}
269265
270266void DeePKS_domain::prepare_gevdm (const int nat,
271- const int lmaxd,
272- const int inlmax,
267+ const DeePKS_Param& deepks_param,
273268 const LCAO_Orbitals& orb,
274269 const std::vector<torch::Tensor>& gevdm_in,
275270 torch::Tensor& gevdm_out)
276271{
277272 ModuleBase::TITLE (" DeePKS_domain" , " prepare_gevdm" );
278273 ModuleBase::timer::tick (" DeePKS_domain" , " prepare_gevdm" );
279- int nlmax = inlmax / nat;
280- int mmax = 2 * lmaxd + 1 ;
274+ int nlmax = deepks_param. inlmax / nat;
275+ int mmax = 2 * deepks_param. lmaxd + 1 ;
281276 gevdm_out = torch::zeros ({nat, nlmax, mmax, mmax, mmax}, torch::TensorOptions ().dtype (torch::kFloat64 ));
282277
283278 std::vector<torch::Tensor> gevdm_out_vector;
@@ -295,42 +290,35 @@ void DeePKS_domain::prepare_gevdm(const int nat,
295290}
296291
297292template void DeePKS_domain::cal_v_delta_precalc<double >(const int nlocal,
298- const int lmaxd,
299- const int inlmax,
300293 const int nat,
301294 const int nks,
302- const std::vector< int >& inl2l ,
295+ const DeePKS_Param& deepks_param ,
303296 const std::vector<ModuleBase::Vector3<double >>& kvec_d,
304297 const std::vector<hamilt::HContainer<double >*> phialpha,
305298 const std::vector<torch::Tensor> gevdm,
306- const ModuleBase::IntArray* inl_index,
307299 const UnitCell& ucell,
308300 const LCAO_Orbitals& orb,
309301 const Parallel_Orbitals& pv,
310302 const Grid_Driver& GridD,
311303 torch::Tensor& v_delta_precalc);
312304template void DeePKS_domain::cal_v_delta_precalc<std::complex <double >>(
313305 const int nlocal,
314- const int lmaxd,
315- const int inlmax,
316306 const int nat,
317307 const int nks,
318- const std::vector< int >& inl2l ,
308+ const DeePKS_Param& deepks_param ,
319309 const std::vector<ModuleBase::Vector3<double >>& kvec_d,
320310 const std::vector<hamilt::HContainer<double >*> phialpha,
321311 const std::vector<torch::Tensor> gevdm,
322- const ModuleBase::IntArray* inl_index,
323312 const UnitCell& ucell,
324313 const LCAO_Orbitals& orb,
325314 const Parallel_Orbitals& pv,
326315 const Grid_Driver& GridD,
327316 torch::Tensor& v_delta_precalc);
328317
329318template void DeePKS_domain::prepare_phialpha<double >(const int nlocal,
330- const int lmaxd,
331- const int inlmax,
332319 const int nat,
333320 const int nks,
321+ const DeePKS_Param& deepks_param,
334322 const std::vector<ModuleBase::Vector3<double >>& kvec_d,
335323 const std::vector<hamilt::HContainer<double >*> phialpha,
336324 const UnitCell& ucell,
@@ -341,10 +329,9 @@ template void DeePKS_domain::prepare_phialpha<double>(const int nlocal,
341329
342330template void DeePKS_domain::prepare_phialpha<std::complex <double >>(
343331 const int nlocal,
344- const int lmaxd,
345- const int inlmax,
346332 const int nat,
347333 const int nks,
334+ const DeePKS_Param& deepks_param,
348335 const std::vector<ModuleBase::Vector3<double >>& kvec_d,
349336 const std::vector<hamilt::HContainer<double >*> phialpha,
350337 const UnitCell& ucell,
0 commit comments