@@ -59,12 +59,18 @@ psi::Psi<std::complex<double>>* wavefunc::allocate(const int nkstot, const int n
5959 if (PARAM.inp .basis_type == " lcao_in_pw" )
6060 {
6161 wanf2[0 ].create (PARAM.globalv .nlocal , npwx * PARAM.globalv .npol );
62- const size_t memory_cost = PARAM.globalv .nlocal * (PARAM.globalv .npol * npwx) * sizeof (std::complex <double >);
63- std::cout << " Memory for wanf2 (MB): " << double (memory_cost) / 1024.0 / 1024.0 << std::endl;
62+
63+ // WARNING: put the sizeof() be the first to avoid the overflow of the multiplication of int
64+ const size_t memory_cost = sizeof (std::complex <double >) * PARAM.globalv .nlocal * (PARAM.globalv .npol * npwx);
65+
66+ std::cout << " Memory for wanf2 (MB): " << static_cast <double >(memory_cost) / 1024.0 / 1024.0 << std::endl;
6467 ModuleBase::Memory::record (" WF::wanf2" , memory_cost);
6568 }
66- const size_t memory_cost = PARAM.inp .nbands * (PARAM.globalv .npol * npwx) * sizeof (std::complex <double >);
67- std::cout << " MEMORY FOR PSI (MB) : " << double (memory_cost) / 1024.0 / 1024.0 << std::endl;
69+
70+ // WARNING: put the sizeof() be the first to avoid the overflow of the multiplication of int
71+ const size_t memory_cost = sizeof (std::complex <double >) * PARAM.inp .nbands * (PARAM.globalv .npol * npwx);
72+
73+ std::cout << " MEMORY FOR PSI (MB) : " << static_cast <double >(memory_cost) / 1024.0 / 1024.0 << std::endl;
6874 ModuleBase::Memory::record (" Psi_PW" , memory_cost);
6975 }
7076 else if (PARAM.inp .basis_type != " pw" )
@@ -82,17 +88,22 @@ psi::Psi<std::complex<double>>* wavefunc::allocate(const int nkstot, const int n
8288 this ->wanf2 [ik].create (PARAM.globalv .nlocal , npwx * PARAM.globalv .npol );
8389 }
8490
85- const size_t memory_cost = nks2 * PARAM.globalv .nlocal * (npwx * PARAM.globalv .npol ) * sizeof (std::complex <double >);
86- std::cout << " Memory for wanf2 (MB): " << double (memory_cost) / 1024.0 / 1024.0 << std::endl;
91+ // WARNING: put the sizeof() be the first to avoid the overflow of the multiplication of int
92+ const size_t memory_cost = sizeof (std::complex <double >) * nks2 * PARAM.globalv .nlocal * (npwx * PARAM.globalv .npol );
93+
94+ std::cout << " Memory for wanf2 (MB): " << static_cast <double >(memory_cost) / 1024.0 / 1024.0 << std::endl;
8795 ModuleBase::Memory::record (" WF::wanf2" , memory_cost);
8896 }
8997 }
9098 else
9199 {
92100 // initial psi rather than evc
93101 psi_out = new psi::Psi<std::complex <double >>(nks2, PARAM.inp .nbands , npwx * PARAM.globalv .npol , ngk);
94- const size_t memory_cost = nks2 * PARAM.inp .nbands * (PARAM.globalv .npol * npwx) * sizeof (std::complex <double >);
95- std::cout << " MEMORY FOR PSI (MB) : " << double (memory_cost) / 1024.0 / 1024.0 << std::endl;
102+
103+ // WARNING: put the sizeof() be the first to avoid the overflow of the multiplication of int
104+ const size_t memory_cost = sizeof (std::complex <double >) * nks2 * PARAM.inp .nbands * (PARAM.globalv .npol * npwx);
105+
106+ std::cout << " MEMORY FOR PSI (MB) : " << static_cast <double >(memory_cost) / 1024.0 / 1024.0 << std::endl;
96107 ModuleBase::Memory::record (" Psi_PW" , memory_cost);
97108 }
98109 return psi_out;
0 commit comments