Skip to content

Commit 20efb88

Browse files
authored
swap the sizeof() be the first multiplier to avoid overflow of int (#5561)
1 parent 8c3def4 commit 20efb88

File tree

2 files changed

+21
-10
lines changed

2 files changed

+21
-10
lines changed

source/module_hamilt_pw/hamilt_pwdft/wavefunc.cpp

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,18 @@ psi::Psi<std::complex<double>>* wavefunc::allocate(const int nkstot, const int n
5959
if (PARAM.inp.basis_type == "lcao_in_pw")
6060
{
6161
wanf2[0].create(PARAM.globalv.nlocal, npwx * PARAM.globalv.npol);
62-
const size_t memory_cost = PARAM.globalv.nlocal * (PARAM.globalv.npol * npwx) * sizeof(std::complex<double>);
63-
std::cout << " Memory for wanf2 (MB): " << double(memory_cost) / 1024.0 / 1024.0 << std::endl;
62+
63+
// WARNING: put the sizeof() be the first to avoid the overflow of the multiplication of int
64+
const size_t memory_cost = sizeof(std::complex<double>) * PARAM.globalv.nlocal * (PARAM.globalv.npol * npwx);
65+
66+
std::cout << " Memory for wanf2 (MB): " << static_cast<double>(memory_cost) / 1024.0 / 1024.0 << std::endl;
6467
ModuleBase::Memory::record("WF::wanf2", memory_cost);
6568
}
66-
const size_t memory_cost = PARAM.inp.nbands * (PARAM.globalv.npol * npwx) * sizeof(std::complex<double>);
67-
std::cout << " MEMORY FOR PSI (MB) : " << double(memory_cost) / 1024.0 / 1024.0 << std::endl;
69+
70+
// WARNING: put the sizeof() be the first to avoid the overflow of the multiplication of int
71+
const size_t memory_cost = sizeof(std::complex<double>) * PARAM.inp.nbands * (PARAM.globalv.npol * npwx);
72+
73+
std::cout << " MEMORY FOR PSI (MB) : " << static_cast<double>(memory_cost) / 1024.0 / 1024.0 << std::endl;
6874
ModuleBase::Memory::record("Psi_PW", memory_cost);
6975
}
7076
else if (PARAM.inp.basis_type != "pw")
@@ -82,17 +88,22 @@ psi::Psi<std::complex<double>>* wavefunc::allocate(const int nkstot, const int n
8288
this->wanf2[ik].create(PARAM.globalv.nlocal, npwx * PARAM.globalv.npol);
8389
}
8490

85-
const size_t memory_cost = nks2 * PARAM.globalv.nlocal * (npwx * PARAM.globalv.npol) * sizeof(std::complex<double>);
86-
std::cout << " Memory for wanf2 (MB): " << double(memory_cost) / 1024.0 / 1024.0 << std::endl;
91+
// WARNING: put the sizeof() be the first to avoid the overflow of the multiplication of int
92+
const size_t memory_cost = sizeof(std::complex<double>) * nks2 * PARAM.globalv.nlocal * (npwx * PARAM.globalv.npol);
93+
94+
std::cout << " Memory for wanf2 (MB): " << static_cast<double>(memory_cost) / 1024.0 / 1024.0 << std::endl;
8795
ModuleBase::Memory::record("WF::wanf2", memory_cost);
8896
}
8997
}
9098
else
9199
{
92100
// initial psi rather than evc
93101
psi_out = new psi::Psi<std::complex<double>>(nks2, PARAM.inp.nbands, npwx * PARAM.globalv.npol, ngk);
94-
const size_t memory_cost = nks2 * PARAM.inp.nbands * (PARAM.globalv.npol * npwx) * sizeof(std::complex<double>);
95-
std::cout << " MEMORY FOR PSI (MB) : " << double(memory_cost) / 1024.0 / 1024.0 << std::endl;
102+
103+
// WARNING: put the sizeof() be the first to avoid the overflow of the multiplication of int
104+
const size_t memory_cost = sizeof(std::complex<double>) * nks2 * PARAM.inp.nbands * (PARAM.globalv.npol * npwx);
105+
106+
std::cout << " MEMORY FOR PSI (MB) : " << static_cast<double>(memory_cost) / 1024.0 / 1024.0 << std::endl;
96107
ModuleBase::Memory::record("Psi_PW", memory_cost);
97108
}
98109
return psi_out;

source/module_psi/psi_initializer.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,8 @@ psi::Psi<std::complex<double>>* psi_initializer<T, Device>::allocate(const bool
7373
PARAM.inp.nbands, // because no matter what, the wavefunction finally needed has PARAM.inp.nbands bands
7474
nbasis_actual,
7575
this->pw_wfc_->npwk);
76-
double memory_cost_psi = nks_psi * PARAM.inp.nbands * this->pw_wfc_->npwk_max * PARAM.globalv.npol*
77-
sizeof(std::complex<double>);
76+
double memory_cost_psi = sizeof(std::complex<double>) * nks_psi * PARAM.inp.nbands
77+
* this->pw_wfc_->npwk_max * PARAM.globalv.npol;
7878
#ifdef __MPI
7979
// get the correct memory cost for psi by all-reduce sum
8080
Parallel_Reduce::reduce_all(memory_cost_psi);

0 commit comments

Comments
 (0)