77#include < cassert>
88#include < ctime>
99
10+ #include < random>
11+
1012#include " source_base/global_function.h"
1113
1214template <typename T, typename Device>
@@ -19,7 +21,7 @@ Stochastic_WF<T, Device>::~Stochastic_WF()
1921{
2022 delete chi0_cpu;
2123 Device* ctx = {};
22- if (base_device::get_device_type (ctx) == base_device::GpuDevice)
24+ if (base_device::get_device_type<Device> (ctx) == base_device::GpuDevice)
2325 {
2426 delete chi0;
2527 }
@@ -60,18 +62,22 @@ void Stochastic_WF<T, Device>::clean_chiallorder()
6062 delete[] chiallorder;
6163 chiallorder = nullptr ;
6264}
65+
6366template <typename T, typename Device>
6467void Stochastic_WF<T, Device>::init_sto_orbitals(const int seed_in)
6568{
66- const unsigned int rank_seed_offset = 10000 ;
69+ unsigned int final_seed ;
6770 if (seed_in == 0 || seed_in == -1 )
6871 {
69- srand ( static_cast < unsigned int >( time (nullptr )) + GlobalV::MY_RANK * rank_seed_offset); // GlobalV global variables are reserved
72+ final_seed = ( unsigned ) time (nullptr ) + GlobalV::MY_RANK * 10000 ;
7073 }
7174 else
7275 {
73- srand ( static_cast < unsigned int >( std::abs (seed_in)) + (GlobalV::MY_BNDGROUP * GlobalV::NPROC_IN_BNDGROUP + GlobalV::RANK_IN_BPGROUP) * rank_seed_offset) ;
76+ final_seed = ( unsigned ) std::abs (seed_in) + (GlobalV::MY_BNDGROUP * GlobalV::NPROC_IN_BNDGROUP + GlobalV::RANK_IN_BPGROUP) * 10000 ;
7477 }
78+
79+ // initialize the random number generator with the final seed
80+ this ->rng .seed (final_seed);
7581
7682 this ->allocate_chi0 ();
7783 this ->update_sto_orbitals (seed_in);
@@ -119,7 +125,7 @@ void Stochastic_WF<T, Device>::allocate_chi0()
119125
120126 // allocate chi0
121127 Device* ctx = {};
122- if (base_device::get_device_type (ctx) == base_device::GpuDevice)
128+ if (base_device::get_device_type<Device> (ctx) == base_device::GpuDevice)
123129 {
124130 this ->chi0 = new psi::Psi<T, Device>(nks, this ->nchip_max , npwx, this ->ngk , true );
125131 }
@@ -134,19 +140,26 @@ void Stochastic_WF<T, Device>::update_sto_orbitals(const int seed_in)
134140{
135141 const int nchi = PARAM.inp .nbands_sto ;
136142 this ->chi0_cpu ->fix_k (0 );
143+
144+ // Uniform distribution to generate random phases between 0 and 2*pi
145+ std::uniform_real_distribution<double > dist_phi (0.0 , 2.0 * ModuleBase::PI);
146+ // Bernoulli distribution to generate +1/sqrt(nchi) or -1/sqrt(nchi) with equal probability
147+ std::bernoulli_distribution dist_coin (0.5 );
148+
137149 if (seed_in >= 0 )
138150 {
139151 for (int i = 0 ; i < this ->chi0_cpu ->size (); ++i)
140152 {
141- const double phi = 2 * ModuleBase::PI * rand () / double (RAND_MAX);
153+ const double phi = dist_phi ( this -> rng );
142154 this ->chi0_cpu ->get_pointer ()[i] = std::complex <double >(cos (phi), sin (phi)) / sqrt (double (nchi));
143155 }
144156 }
145157 else
146158 {
147159 for (int i = 0 ; i < this ->chi0_cpu ->size (); ++i)
148160 {
149- if (rand () / double (RAND_MAX) < 0.5 )
161+ // use Bernoulli distribution to generate +1/sqrt(nchi) or -1/sqrt(nchi) with equal probability
162+ if (dist_coin (this ->rng ))
150163 {
151164 this ->chi0_cpu ->get_pointer ()[i] = -1.0 / sqrt (double (nchi));
152165 }
0 commit comments