Skip to content

Commit d2d9970

Browse files
committed
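fix bug: pass fft_precision (not PARAM.inp.precision) to PW_Basis_K_Big; set float_flag/double_flag for "mixing" precision in FFT_Bundle::initfft and quit with a warning when float FFTW is not enabled on CPU; in Sto_EleCond::cal_jmatrix, offset batch_vchi/batch_vhchi by idnb, swap the rightchi_all/righthchi_all operands for j2mat/tmpjmat, and take right_hchi and tmpj as caller-provided buffers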
1 parent 045b26d commit d2d9970

File tree

5 files changed

+56
-48
lines changed

5 files changed

+56
-48
lines changed

source/module_basis/module_pw/module_fft/fft_bundle.cpp

Lines changed: 4 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -47,24 +47,17 @@ void FFT_Bundle::initfft(int nx_in,
4747
assert(this->device == "cpu" || this->device == "gpu" || this->device == "dsp");
4848
assert(this->precision == "single" || this->precision == "double" || this->precision == "mixing");
4949

50-
if (this->precision == "single")
50+
if (this->precision == "single" || this->precision == "mixing")
5151
{
52+
float_flag = true;
5253
#if not defined(__ENABLE_FLOAT_FFTW)
5354
if (this->device == "cpu")
5455
{
55-
float_define = false;
56+
ModuleBase::WARNING_QUIT("FFT_Bundle", "Please enable float fftw in the cmake to use float fft");
5657
}
5758
#endif
58-
#if defined(__CUDA) || defined(__ROCM)
59-
if (this->device == "gpu")
60-
{
61-
float_flag = float_define;
62-
}
63-
#endif
64-
float_flag = float_define;
65-
double_flag = true;
6659
}
67-
if (this->precision == "double")
60+
if (this->precision == "double" || this->precision == "mixing")
6861
{
6962
double_flag = true;
7063
}

source/module_basis/module_pw/module_fft/fft_bundle.h

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -196,7 +196,6 @@ class FFT_Bundle
196196
private:
197197
int fft_mode = 0;
198198
bool float_flag = false;
199-
bool float_define = true;
200199
bool double_flag = false;
201200
std::shared_ptr<FFT_BASE<float>> fft_float = nullptr;
202201
std::shared_ptr<FFT_BASE<double>> fft_double = nullptr;

source/module_esolver/esolver_ks.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -66,7 +66,7 @@ ESolver_KS<T, Device>::ESolver_KS()
6666
fft_precision = "mixing";
6767
}
6868
#endif
69-
pw_wfc = new ModulePW::PW_Basis_K_Big(fft_device, PARAM.inp.precision);
69+
pw_wfc = new ModulePW::PW_Basis_K_Big(fft_device, fft_precision);
7070
ModulePW::PW_Basis_K_Big* tmp = static_cast<ModulePW::PW_Basis_K_Big*>(pw_wfc);
7171

7272
// should not use INPUT here, mohan 2024-05-12

source/module_hamilt_pw/hamilt_stodft/sto_elecond.cpp

Lines changed: 47 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -33,10 +33,11 @@ Sto_EleCond<FPTYPE, Device>::Sto_EleCond(UnitCell* p_ucell_in,
3333
this->nbands_ks = p_psi_in->get_nbands();
3434
this->nbands_sto = p_stowf_in->nchi;
3535
this->stofunc.set_E_range(&stoche.emin_sto, &stoche.emax_sto);
36+
this->cond_dtbatch = PARAM.inp.cond_dtbatch;
3637
#ifdef __ENABLE_FLOAT_FFTW
3738
if(!std::is_same<FPTYPE, lowTYPE>::value)
3839
{
39-
this->hamilt_sto_ = new hamilt::HamiltSdftPW<std::complex<lowTYPE>, Device>(p_elec_in->pot, p_wfcpw_in, p_kv_in, p_ppcell_in, p_ucell_in, 1, &this->emin_sto_, &this->emax_sto_);
40+
this->hamilt_sto_ = new hamilt::HamiltSdftPW<std::complex<lowTYPE>, Device>(p_elec_in->pot, p_wfcpw_in, p_kv_in, p_ppcell_in, p_ucell_in, 1, &this->low_emin_, &this->low_emax_);
4041
}
4142
#endif
4243
}
@@ -149,6 +150,7 @@ void Sto_EleCond<FPTYPE, Device>::cal_jmatrix(hamilt::HamiltSdftPW<std::complex<
149150
psi::Psi<std::complex<lowTYPE>, Device>& leftchi,
150151
psi::Psi<std::complex<lowTYPE>, Device>& rightchi,
151152
psi::Psi<std::complex<lowTYPE>, Device>& left_hchi,
153+
psi::Psi<std::complex<lowTYPE>, Device>& right_hchi,
152154
psi::Psi<std::complex<lowTYPE>, Device>& batch_vchi,
153155
psi::Psi<std::complex<lowTYPE>, Device>& batch_vhchi,
154156
#ifdef __MPI
@@ -160,6 +162,7 @@ void Sto_EleCond<FPTYPE, Device>::cal_jmatrix(hamilt::HamiltSdftPW<std::complex<
160162
const int& bsize_psi,
161163
std::complex<lowTYPE>* j1,
162164
std::complex<lowTYPE>* j2,
165+
std::complex<lowTYPE>* tmpj,
163166
hamilt::Velocity<lowTYPE, Device>& velop,
164167
const int& ik,
165168
const std::complex<lowTYPE>& factor,
@@ -181,8 +184,6 @@ void Sto_EleCond<FPTYPE, Device>::cal_jmatrix(hamilt::HamiltSdftPW<std::complex<
181184
const int allbands = bandinfo[5];
182185
const int dim_jmatrix = perbands_ks * allbands_sto + perbands_sto * allbands;
183186

184-
psi::Psi<std::complex<lowTYPE>, Device> right_hchi(1, perbands_sto, npwx, npw, true);
185-
186187
hamilt->hPsi(leftchi.get_pointer(), left_hchi.get_pointer(), perbands_sto);
187188
hamilt->hPsi(rightchi.get_pointer(), right_hchi.get_pointer(), perbands_sto);
188189

@@ -261,8 +262,6 @@ void Sto_EleCond<FPTYPE, Device>::cal_jmatrix(hamilt::HamiltSdftPW<std::complex<
261262
}
262263
}
263264

264-
std::complex<lowTYPE>* tmpj = nullptr;
265-
resmem_lcomplex_op()(tmpj, allbands_sto * perbands_sto);
266265
int remain = perbands_sto;
267266
int startnb = 0;
268267
while (remain > 0)
@@ -289,7 +288,7 @@ void Sto_EleCond<FPTYPE, Device>::cal_jmatrix(hamilt::HamiltSdftPW<std::complex<
289288
allbands_ks,
290289
npw,
291290
&float_factor,
292-
batch_vchi.get_pointer(),
291+
&batch_vchi(idnb, 0),
293292
npwx,
294293
kspsi_all.get_pointer(),
295294
npwx,
@@ -316,7 +315,7 @@ void Sto_EleCond<FPTYPE, Device>::cal_jmatrix(hamilt::HamiltSdftPW<std::complex<
316315
allbands_ks,
317316
npw,
318317
&float_factor,
319-
batch_vhchi.get_pointer(),
318+
&batch_vhchi(idnb, 0),
320319
npwx,
321320
kspsi_all.get_pointer(),
322321
npwx,
@@ -342,7 +341,7 @@ void Sto_EleCond<FPTYPE, Device>::cal_jmatrix(hamilt::HamiltSdftPW<std::complex<
342341
allbands_sto,
343342
npw,
344343
&float_factor,
345-
batch_vchi.get_pointer(),
344+
&batch_vchi(idnb, 0),
346345
npwx,
347346
rightchi_all->get_pointer(),
348347
npwx,
@@ -357,9 +356,9 @@ void Sto_EleCond<FPTYPE, Device>::cal_jmatrix(hamilt::HamiltSdftPW<std::complex<
357356
allbands_sto,
358357
npw,
359358
&float_factor,
360-
batch_vhchi.get_pointer(),
359+
&batch_vhchi(idnb, 0),
361360
npwx,
362-
righthchi_all->get_pointer(),
361+
rightchi_all->get_pointer(),
363362
npwx,
364363
&zero,
365364
j2mat,
@@ -372,9 +371,9 @@ void Sto_EleCond<FPTYPE, Device>::cal_jmatrix(hamilt::HamiltSdftPW<std::complex<
372371
allbands_sto,
373372
npw,
374373
&float_factor,
375-
batch_vchi.get_pointer(),
374+
&batch_vchi(idnb, 0),
376375
npwx,
377-
rightchi_all->get_pointer(),
376+
righthchi_all->get_pointer(),
378377
npwx,
379378
&zero,
380379
tmpjmat,
@@ -470,7 +469,6 @@ void Sto_EleCond<FPTYPE, Device>::cal_jmatrix(hamilt::HamiltSdftPW<std::complex<
470469
Parallel_Common::reduce_data(j2, ndim * dim_jmatrix, POOL_WORLD);
471470
}
472471
#endif
473-
474472
ModuleBase::timer::tick("Sto_EleCond", "cal_jmatrix");
475473

476474
return;
@@ -550,9 +548,9 @@ void Sto_EleCond<FPTYPE, Device>::sKG(const int& smear_type,
550548
std::complex<lowTYPE> zero = static_cast<std::complex<lowTYPE>>(0.0);
551549
std::complex<lowTYPE> imag_one = static_cast<std::complex<lowTYPE>>(ModuleBase::IMAG_UNIT);
552550
Sto_Func<lowTYPE> lowfunc;
553-
lowTYPE low_emin = static_cast<lowTYPE>(*this->stofunc.Emin);
554-
lowTYPE low_emax = static_cast<lowTYPE>(*this->stofunc.Emax);
555-
lowfunc.set_E_range(&low_emin, &low_emax);
551+
this->low_emin_ = static_cast<lowTYPE>(*this->stofunc.Emin);
552+
this->low_emax_ = static_cast<lowTYPE>(*this->stofunc.Emax);
553+
lowfunc.set_E_range(&low_emin_, &low_emax_);
556554
hamilt::HamiltSdftPW<lcomplex, Device>* p_low_hamilt = nullptr;
557555
if(hamilt_sto_ != nullptr)
558556
{
@@ -593,9 +591,9 @@ void Sto_EleCond<FPTYPE, Device>::sKG(const int& smear_type,
593591
// std::complex<lowTYPE>* tmpcoef = batchcoef_ + (nbatch - 1) * cond_nche;
594592
// resmem_lcomplex_op()(batchmcoef_, cond_nche * nbatch);
595593
// std::complex<lowTYPE>* tmpmcoef = batchmcoef_ + (nbatch - 1) * cond_nche;
596-
batchcoef.reshape({nbatch, cond_nche});
594+
batchcoef.resize({nbatch, cond_nche});
597595
lcomplex* tmpcoef = batchcoef[nbatch-1].data<lcomplex>();
598-
batchmcoef.reshape({nbatch, cond_nche});
596+
batchmcoef.resize({nbatch, cond_nche});
599597
lcomplex* tmpmcoef = batchmcoef[nbatch-1].data<lcomplex>();
600598

601599
cpymem_lcomplex_op()(tmpcoef, chet.coef_complex, cond_nche);
@@ -635,8 +633,8 @@ void Sto_EleCond<FPTYPE, Device>::sKG(const int& smear_type,
635633

636634
// get allbands_ks
637635
int cutib0 = 0;
638-
double emin = static_cast<double>(*this->stofunc.Emin);
639-
double emax = static_cast<double>(*this->stofunc.Emax);
636+
const double emin = static_cast<double>(*this->stofunc.Emin);
637+
const double emax = static_cast<double>(*this->stofunc.Emax);
640638
if (this->nbands_ks > 0)
641639
{
642640
double Emax_KS = std::max(emin, this->p_elec->ekb(ik, this->nbands_ks - 1));
@@ -697,7 +695,7 @@ void Sto_EleCond<FPTYPE, Device>::sKG(const int& smear_type,
697695
//-----------------------------------------------------------
698696
if (GlobalV::MY_BNDGROUP == 0 && allbands_ks > 0)
699697
{
700-
jjresponse_ks(ik, nt, dt, dEcut, this->p_elec->wg, velop, ct11.data(), ct12.data(), ct22.data());
698+
this->jjresponse_ks(ik, nt, dt, dEcut, this->p_elec->wg, velop, ct11.data(), ct12.data(), ct22.data());
701699
}
702700

703701
//-----------------------------------------------------------
@@ -823,11 +821,14 @@ void Sto_EleCond<FPTYPE, Device>::sKG(const int& smear_type,
823821
ct::Tensor j2l(t_type, device_type, {ndim, dim_jmatrix});
824822
ct::Tensor j1r(t_type, device_type, {ndim, dim_jmatrix});
825823
ct::Tensor j2r(t_type, device_type, {ndim, dim_jmatrix});
824+
ct::Tensor tmpj(t_type, device_type, {ndim, allbands_sto * perbands_sto});
826825
ModuleBase::Memory::record("SDFT::j1l", sizeof(lcomplex) * ndim * dim_jmatrix);
827826
ModuleBase::Memory::record("SDFT::j2l", sizeof(lcomplex) * ndim * dim_jmatrix);
828827
ModuleBase::Memory::record("SDFT::j1r", sizeof(lcomplex) * ndim * dim_jmatrix);
829828
ModuleBase::Memory::record("SDFT::j2r", sizeof(lcomplex) * ndim * dim_jmatrix);
829+
ModuleBase::Memory::record("SDFT::tmpj", sizeof(lcomplex) * ndim * allbands_sto * perbands_sto);
830830
psi::Psi<lcomplex, Device> tmphchil(1, perbands_sto, npwx, npw, true);
831+
psi::Psi<lcomplex, Device> tmphchir(1, perbands_sto, npwx, npw, true);
831832
ModuleBase::Memory::record("SDFT::tmphchil/r", sto_memory_cost * 2);
832833

833834
//------------------------ t loop --------------------------
@@ -978,6 +979,7 @@ void Sto_EleCond<FPTYPE, Device>::sKG(const int& smear_type,
978979
exptsmfchi,
979980
exptsfchi,
980981
tmphchil,
982+
tmphchir,
981983
batch_vchi,
982984
batch_vhchi,
983985
#ifdef __MPI
@@ -989,6 +991,7 @@ void Sto_EleCond<FPTYPE, Device>::sKG(const int& smear_type,
989991
bsize_psi,
990992
j1l.data<lcomplex>(),
991993
j2l.data<lcomplex>(),
994+
tmpj.data<lcomplex>(),
992995
low_velop,
993996
ik,
994997
imag_one,
@@ -1005,6 +1008,7 @@ void Sto_EleCond<FPTYPE, Device>::sKG(const int& smear_type,
10051008
expmtsmfchi,
10061009
expmtsfchi,
10071010
tmphchil,
1011+
tmphchir,
10081012
batch_vchi,
10091013
batch_vhchi,
10101014
#ifdef __MPI
@@ -1016,6 +1020,7 @@ void Sto_EleCond<FPTYPE, Device>::sKG(const int& smear_type,
10161020
bsize_psi,
10171021
j1r.data<lcomplex>(),
10181022
j2r.data<lcomplex>(),
1023+
tmpj.data<lcomplex>(),
10191024
low_velop,
10201025
ik,
10211026
one,
@@ -1028,21 +1033,30 @@ void Sto_EleCond<FPTYPE, Device>::sKG(const int& smear_type,
10281033
// Im(l_ij*r_ji) = Re(-il_ij * r_ji) = Re( ((il)^+_ji)^* * r_ji)=Re(((il)^+_i)^* * r^+_i)
10291034
// ddot_real = real(A_i^* * B_i)
10301035
ModuleBase::timer::tick("Sto_EleCond", "ddot_real");
1031-
ct11[it] += static_cast<double>(
1032-
ModuleBase::GlobalFunc::ddot_real(num_per, j1l.data<lcomplex>() + st_per, j1r.data<lcomplex>() + st_per, false)
1033-
* this->p_kv->wk[ik] / 2.0);
1034-
double tmp12 = static_cast<double>(
1035-
ModuleBase::GlobalFunc::ddot_real(num_per, j1l.data<lcomplex>() + st_per, j2r.data<lcomplex>() + st_per, false));
1036-
1037-
double tmp21 = static_cast<double>(
1038-
ModuleBase::GlobalFunc::ddot_real(num_per, j2l.data<lcomplex>() + st_per, j1r.data<lcomplex>() + st_per, false));
1036+
ct11[it] += static_cast<double>(ModuleBase::dot_real_op<lcomplex, Device>()(num_per,
1037+
j1l.data<lcomplex>() + st_per,
1038+
j1r.data<lcomplex>() + st_per,
1039+
false)
1040+
* this->p_kv->wk[ik] / 2.0);
1041+
double tmp12
1042+
= static_cast<double>(ModuleBase::dot_real_op<lcomplex, Device>()(num_per,
1043+
j1l.data<lcomplex>() + st_per,
1044+
j2r.data<lcomplex>() + st_per,
1045+
false));
1046+
1047+
double tmp21
1048+
= static_cast<double>(ModuleBase::dot_real_op<lcomplex, Device>()(num_per,
1049+
j2l.data<lcomplex>() + st_per,
1050+
j1r.data<lcomplex>() + st_per,
1051+
false));
10391052

10401053
ct12[it] -= 0.5 * (tmp12 + tmp21) * this->p_kv->wk[ik] / 2.0;
10411054

1042-
ct22[it] += static_cast<double>(
1043-
ModuleBase::GlobalFunc::ddot_real(num_per, j2l.data<lcomplex>() + st_per, j2r.data<lcomplex>() + st_per, false)
1044-
* this->p_kv->wk[ik] / 2.0);
1045-
1055+
ct22[it] += static_cast<double>(ModuleBase::dot_real_op<lcomplex, Device>()(num_per,
1056+
j2l.data<lcomplex>() + st_per,
1057+
j2r.data<lcomplex>() + st_per,
1058+
false)
1059+
* this->p_kv->wk[ik] / 2.0);
10461060
ModuleBase::timer::tick("Sto_EleCond", "ddot_real");
10471061
}
10481062
std::cout << std::endl;

source/module_hamilt_pw/hamilt_stodft/sto_elecond.h

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -78,8 +78,8 @@ class Sto_EleCond : protected EleCond<FPTYPE, Device>
7878

7979
hamilt::HamiltSdftPW<std::complex<FPTYPE>, Device>* p_hamilt_sto = nullptr; ///< pointer to the Hamiltonian for sDFT
8080
hamilt::HamiltSdftPW<std::complex<lowTYPE>, Device>* hamilt_sto_ = nullptr; ///< pointer to the Hamiltonian for sDFT
81-
lowTYPE emin_sto_ = 0; ///< Emin of the Hamiltonian for sDFT
82-
lowTYPE emax_sto_ = 0; ///< Emax of the Hamiltonian for sDFT
81+
lowTYPE low_emin_ = 0; ///< Emin of the Hamiltonian for sDFT
82+
lowTYPE low_emax_ = 0; ///< Emax of the Hamiltonian for sDFT
8383
protected:
8484
/**
8585
* @brief calculate Jmatrix <leftv|J|rightv>
@@ -95,6 +95,7 @@ class Sto_EleCond : protected EleCond<FPTYPE, Device>
9595
psi::Psi<std::complex<lowTYPE>, Device>& leftchi,
9696
psi::Psi<std::complex<lowTYPE>, Device>& rightchi,
9797
psi::Psi<std::complex<lowTYPE>, Device>& left_hchi,
98+
psi::Psi<std::complex<lowTYPE>, Device>& right_hchi,
9899
psi::Psi<std::complex<lowTYPE>, Device>& batch_vchi,
99100
psi::Psi<std::complex<lowTYPE>, Device>& batch_vhchi,
100101
#ifdef __MPI
@@ -106,6 +107,7 @@ class Sto_EleCond : protected EleCond<FPTYPE, Device>
106107
const int& bsize_psi,
107108
std::complex<lowTYPE>* j1,
108109
std::complex<lowTYPE>* j2,
110+
std::complex<lowTYPE>* tmpj,
109111
hamilt::Velocity<lowTYPE, Device>& velop,
110112
const int& ik,
111113
const std::complex<lowTYPE>& factor,

0 commit comments

Comments
 (0)