Skip to content

Commit 44ff32e

Browse files
authored
Fix reinterpret bug for c10::complex<double>. (#6321)
1 parent 7620b59 commit 44ff32e

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

source/module_hamilt_lcao/module_deepks/deepks_check.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ void DeePKS_domain::check_tensor(const torch::Tensor& tensor, const std::string&
4949
ofs.close();
5050
}
5151

52+
template void DeePKS_domain::check_tensor<int>(const torch::Tensor& tensor, const std::string& filename, const int rank);
5253
template void DeePKS_domain::check_tensor<double>(const torch::Tensor& tensor, const std::string& filename, const int rank);
5354
template void DeePKS_domain::check_tensor<std::complex<double>>(const torch::Tensor& tensor, const std::string& filename, const int rank);
5455

source/module_hamilt_lcao/module_deepks/deepks_vdpre.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,7 @@ void DeePKS_domain::cal_v_delta_precalc(const int nlocal,
4848

4949
torch::Tensor v_delta_pdm
5050
= torch::zeros({nks, nlocal, nlocal, inlmax, (2 * lmaxd + 1), (2 * lmaxd + 1)}, torch::dtype(dtype));
51-
auto accessor
52-
= v_delta_pdm.accessor<std::conditional_t<std::is_same<TK, double>::value, double, c10::complex<double>>, 6>();
51+
auto accessor = v_delta_pdm.accessor<TK_tensor, 6>();
5352

5453
DeePKS_domain::iterate_ad2(
5554
ucell,
@@ -108,7 +107,7 @@ void DeePKS_domain::cal_v_delta_precalc(const int nlocal,
108107
= (kvec_d[ik] * ModuleBase::Vector3<double>(dR1 - dR2)) * ModuleBase::TWO_PI;
109108
kphase = std::complex<double>(cos(arg), sin(arg));
110109
}
111-
TK_tensor* kpase_ptr = reinterpret_cast<TK_tensor*>(&kphase);
110+
TK* kpase_ptr = reinterpret_cast<TK*>(&kphase);
112111
for (int L0 = 0; L0 <= orb.Alpha[0].getLmax(); ++L0)
113112
{
114113
for (int N0 = 0; N0 < orb.Alpha[0].getNchi(L0); ++N0)
@@ -119,9 +118,10 @@ void DeePKS_domain::cal_v_delta_precalc(const int nlocal,
119118
{
120119
for (int m2 = 0; m2 < nm; ++m2) // nm = 1 for s, 3 for p, 5 for d
121120
{
122-
TK_tensor tmp = overlap_1->get_value(iw1, ib + m1)
121+
TK tmp = overlap_1->get_value(iw1, ib + m1)
123122
* overlap_2->get_value(iw2, ib + m2) * *kpase_ptr;
124-
accessor[ik][iw1_all][iw2_all][inl][m1][m2] += tmp;
123+
TK_tensor tmp_tensor = TK_tensor(tmp);
124+
accessor[ik][iw1_all][iw2_all][inl][m1][m2] += tmp_tensor;
125125
}
126126
}
127127
ib += nm;
@@ -193,8 +193,7 @@ void DeePKS_domain::prepare_phialpha(const int nlocal,
193193
int nlmax = inlmax / nat;
194194
int mmax = 2 * lmaxd + 1;
195195
phialpha_out = torch::zeros({nat, nlmax, nks, nlocal, mmax}, dtype);
196-
auto accessor
197-
= phialpha_out.accessor<std::conditional_t<std::is_same<TK, double>::value, double, c10::complex<double>>, 5>();
196+
auto accessor = phialpha_out.accessor<TK_tensor, 5>();
198197

199198
DeePKS_domain::iterate_ad1(
200199
ucell,

0 commit comments

Comments
 (0)