Skip to content

Commit 09e0d81

Browse files
committed
Refactor: upgrade stochastic wavefunction random generator to C++11 std::mt19937
1 parent e97cf59 commit 09e0d81

2 files changed

Lines changed: 24 additions & 7 deletions

File tree

source/source_pw/module_stodft/sto_wf.cpp

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
#include <cassert>
88
#include <ctime>
99

10+
#include <random>
11+
1012
#include "source_base/global_function.h"
1113

1214
template <typename T, typename Device>
@@ -19,7 +21,7 @@ Stochastic_WF<T, Device>::~Stochastic_WF()
1921
{
2022
delete chi0_cpu;
2123
Device* ctx = {};
22-
if (base_device::get_device_type(ctx) == base_device::GpuDevice)
24+
if (base_device::get_device_type<Device>(ctx) == base_device::GpuDevice)
2325
{
2426
delete chi0;
2527
}
@@ -60,18 +62,22 @@ void Stochastic_WF<T, Device>::clean_chiallorder()
6062
delete[] chiallorder;
6163
chiallorder = nullptr;
6264
}
65+
6366
template <typename T, typename Device>
6467
void Stochastic_WF<T, Device>::init_sto_orbitals(const int seed_in)
6568
{
66-
const unsigned int rank_seed_offset = 10000;
69+
unsigned int final_seed;
6770
if (seed_in == 0 || seed_in == -1)
6871
{
69-
srand(static_cast<unsigned int>(time(nullptr)) + GlobalV::MY_RANK * rank_seed_offset); // GlobalV global variables are reserved
72+
final_seed = (unsigned)time(nullptr) + GlobalV::MY_RANK * 10000;
7073
}
7174
else
7275
{
73-
srand(static_cast<unsigned int>(std::abs(seed_in)) + (GlobalV::MY_BNDGROUP * GlobalV::NPROC_IN_BNDGROUP + GlobalV::RANK_IN_BPGROUP) * rank_seed_offset);
76+
final_seed = (unsigned)std::abs(seed_in) + (GlobalV::MY_BNDGROUP * GlobalV::NPROC_IN_BNDGROUP + GlobalV::RANK_IN_BPGROUP) * 10000;
7477
}
78+
79+
// initialize the random number generator with the final seed
80+
this->rng.seed(final_seed);
7581

7682
this->allocate_chi0();
7783
this->update_sto_orbitals(seed_in);
@@ -119,7 +125,7 @@ void Stochastic_WF<T, Device>::allocate_chi0()
119125

120126
// allocate chi0
121127
Device* ctx = {};
122-
if (base_device::get_device_type(ctx) == base_device::GpuDevice)
128+
if (base_device::get_device_type<Device>(ctx) == base_device::GpuDevice)
123129
{
124130
this->chi0 = new psi::Psi<T, Device>(nks, this->nchip_max, npwx, this->ngk, true);
125131
}
@@ -134,19 +140,26 @@ void Stochastic_WF<T, Device>::update_sto_orbitals(const int seed_in)
134140
{
135141
const int nchi = PARAM.inp.nbands_sto;
136142
this->chi0_cpu->fix_k(0);
143+
144+
// Uniform distribution to generate random phases between 0 and 2*pi
145+
std::uniform_real_distribution<double> dist_phi(0.0, 2.0 * ModuleBase::PI);
146+
// Bernoulli distribution to generate +1/sqrt(nchi) or -1/sqrt(nchi) with equal probability
147+
std::bernoulli_distribution dist_coin(0.5);
148+
137149
if (seed_in >= 0)
138150
{
139151
for (int i = 0; i < this->chi0_cpu->size(); ++i)
140152
{
141-
const double phi = 2 * ModuleBase::PI * rand() / double(RAND_MAX);
153+
const double phi = dist_phi(this->rng);
142154
this->chi0_cpu->get_pointer()[i] = std::complex<double>(cos(phi), sin(phi)) / sqrt(double(nchi));
143155
}
144156
}
145157
else
146158
{
147159
for (int i = 0; i < this->chi0_cpu->size(); ++i)
148160
{
149-
if (rand() / double(RAND_MAX) < 0.5)
161+
// use Bernoulli distribution to generate +1/sqrt(nchi) or -1/sqrt(nchi) with equal probability
162+
if (dist_coin(this->rng))
150163
{
151164
this->chi0_cpu->get_pointer()[i] = -1.0 / sqrt(double(nchi));
152165
}

source/source_pw/module_stodft/sto_wf.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,10 @@ class Stochastic_WF
5959
void init_com_orbitals();
6060
// sync chi0 from CPU to GPU
6161
void sync_chi0();
62+
63+
private:
64+
// random number generator
65+
std::mt19937 rng;
6266

6367
protected:
6468
using setmem_complex_op = base_device::memory::set_memory_op<T, Device>;

0 commit comments

Comments
 (0)