@@ -48,8 +48,7 @@ void DeePKS_domain::cal_v_delta_precalc(const int nlocal,
4848
4949 torch::Tensor v_delta_pdm
5050 = torch::zeros ({nks, nlocal, nlocal, inlmax, (2 * lmaxd + 1 ), (2 * lmaxd + 1 )}, torch::dtype (dtype));
51- auto accessor
52- = v_delta_pdm.accessor <std::conditional_t <std::is_same<TK, double >::value, double , c10::complex <double >>, 6 >();
51+ auto accessor = v_delta_pdm.accessor <TK_tensor, 6 >();
5352
5453 DeePKS_domain::iterate_ad2 (
5554 ucell,
@@ -108,7 +107,7 @@ void DeePKS_domain::cal_v_delta_precalc(const int nlocal,
108107 = (kvec_d[ik] * ModuleBase::Vector3<double >(dR1 - dR2)) * ModuleBase::TWO_PI;
109108 kphase = std::complex <double >(cos (arg), sin (arg));
110109 }
111- TK_tensor * kpase_ptr = reinterpret_cast <TK_tensor *>(&kphase);
110+ TK * kpase_ptr = reinterpret_cast <TK *>(&kphase);
112111 for (int L0 = 0 ; L0 <= orb.Alpha [0 ].getLmax (); ++L0)
113112 {
114113 for (int N0 = 0 ; N0 < orb.Alpha [0 ].getNchi (L0); ++N0)
@@ -119,9 +118,10 @@ void DeePKS_domain::cal_v_delta_precalc(const int nlocal,
119118 {
120119 for (int m2 = 0 ; m2 < nm; ++m2) // nm = 1 for s, 3 for p, 5 for d
121120 {
122- TK_tensor tmp = overlap_1->get_value (iw1, ib + m1)
121+ TK tmp = overlap_1->get_value (iw1, ib + m1)
123122 * overlap_2->get_value (iw2, ib + m2) * *kpase_ptr;
124- accessor[ik][iw1_all][iw2_all][inl][m1][m2] += tmp;
123+ TK_tensor tmp_tensor = TK_tensor (tmp);
124+ accessor[ik][iw1_all][iw2_all][inl][m1][m2] += tmp_tensor;
125125 }
126126 }
127127 ib += nm;
@@ -193,8 +193,7 @@ void DeePKS_domain::prepare_phialpha(const int nlocal,
193193 int nlmax = inlmax / nat;
194194 int mmax = 2 * lmaxd + 1 ;
195195 phialpha_out = torch::zeros ({nat, nlmax, nks, nlocal, mmax}, dtype);
196- auto accessor
197- = phialpha_out.accessor <std::conditional_t <std::is_same<TK, double >::value, double , c10::complex <double >>, 5 >();
196+ auto accessor = phialpha_out.accessor <TK_tensor, 5 >();
198197
199198 DeePKS_domain::iterate_ad1 (
200199 ucell,
0 commit comments