@@ -16,7 +16,6 @@ limitations under the License. */
1616
1717#include < glog/logging.h>
1818
19- #include < cstdint>
2019#include < memory>
2120#include < utility>
2221
@@ -158,125 +157,134 @@ const std::shared_ptr<Generator>& GetRandomSeedGenerator(
158157// RandomGenerator.
159158std::shared_ptr<std::mt19937_64> GetCPURandomEngine (uint64_t seed) {
160159 if (seed == 0 ) {
161- VLOG (4 ) << " Use random cpu_engine from generator" ;
160+ VLOG (4 ) << " Use random engine from generator" ;
162161 return DefaultCPUGenerator ()->GetCPUEngine ();
163162 } else {
164- // NOTE(zhiqiu): creating an cpu_engine instance everytime instead of using
163+ // NOTE(zhiqiu): creating an engine instance every time instead of using
165164 // OpDefaultCPUEngine(), this is the legacy behavior of random operators.
166165 // The benefit is that when running PE with fixed-seed in multiple threads,
167- // each thread has their own cpu_engine , and doesn't affect each other.
166+ // each thread has its own engine, and doesn't affect each other.
168167 //
169168 // And we need to measure the determinacy of Generator in PE.
170- auto cpu_engine = std::make_shared<std::mt19937_64>();
169+ auto engine = std::make_shared<std::mt19937_64>();
171170 static std::mutex mu_;
172171 {
173172 std::lock_guard<std::mutex> lock (mu_);
174- cpu_engine ->seed (seed);
173+ engine ->seed (seed);
175174 }
176- return cpu_engine ;
175+ return engine ;
177176 }
178177}
179178
180- inline void Generator::print_state_info () {
181- VLOG (4 ) << " Generator Random state "
182- << " device id: " << state ().device << " , seed: " << state ().seed
183- << " , offset: " << state ().offset << " , cpu_engine: " << cpu_engine ();
184- }
185-
186179Generator::Generator () {
187180 auto seed = GetRandomSeed ();
188- current_index = states_.size ();
189- states_.emplace_back (seed);
190- print_state_info ();
181+ std::seed_seq seq ({seed});
182+ auto engine = std::make_shared<std::mt19937_64>(seq);
183+ this ->state_ .cpu_engine = *engine;
184+ this ->state_ .device = -1 ;
185+ this ->state_ .current_seed = seed;
186+ this ->state_ .thread_offset = 0 ;
187+ this ->engine_ = engine;
188+ VLOG (4 ) << " initial seed: " << this ->state_ .current_seed
189+ << " , cpu engine: " << &this ->state_ .cpu_engine ;
191190}
192191
193192Generator::Generator (uint64_t seed) {
194- current_index = states_.size ();
195- states_.emplace_back (-1 , seed);
196- print_state_info ();
197- }
198-
199- Generator::Generator (uint64_t seed, int64_t device_id) {
200- current_index = states_.size ();
201- // device id first, then seed
202- states_.emplace_back (device_id, seed);
203- print_state_info ();
204- }
205-
206- phi::Generator::GeneratorState Generator::GetState () { return state ().clone (); }
207-
208- void Generator::SetState (const phi::Generator::GeneratorState& state) {
209- std::lock_guard<std::mutex> lock (mu_);
210- if (current_index < states_.size ())
211- states_[current_index] = state.clone ();
212- else
213- PADDLE_THROW (phi::errors::NotFound (" Generator index is not found" ));
214- print_state_info ();
215- }
216-
217- uint64_t Generator::GetStateIndex () { return current_index; }
218-
219- void Generator::SetStateIndex (uint64_t StateIndex) {
220- std::lock_guard<std::mutex> lock (mu_);
221- if (current_index < states_.size ())
222- current_index = StateIndex;
223- else
224- PADDLE_THROW (phi::errors::NotFound (" Generator index is not found" ));
193+ std::seed_seq seq ({seed});
194+ auto engine = std::make_shared<std::mt19937_64>(seq);
195+ this ->state_ .cpu_engine = *engine;
196+ this ->state_ .device = -1 ;
197+ this ->state_ .current_seed = seed;
198+ this ->state_ .thread_offset = 0 ;
199+ this ->engine_ = engine;
200+ VLOG (4 ) << " initial seed: " << this ->state_ .current_seed
201+ << " , cpu engine: " << &this ->state_ .cpu_engine ;
225202}
226203
227- uint64_t Generator::RegisterStateIndex (const GeneratorState& state) {
228- std::lock_guard<std::mutex> lock (mu_);
229- auto new_index = states_.size ();
230- states_.push_back (state);
231- current_index = new_index;
232- return new_index;
204+ Generator::Generator (uint64_t seed, uint64_t device_id) {
205+ std::seed_seq seq ({seed});
206+ auto engine = std::make_shared<std::mt19937_64>(seq);
207+ this ->state_ .cpu_engine = *engine;
208+ this ->state_ .device = static_cast <int64_t >(device_id);
209+ this ->state_ .current_seed = seed;
210+ this ->state_ .thread_offset = 0 ;
211+ this ->engine_ = engine;
212+ VLOG (4 ) << " initial seed: " << this ->state_ .current_seed
213+ << " , cpu engine: " << &this ->state_ .cpu_engine ;
233214}
234215
235- inline Generator::GeneratorState& Generator::state () {
236- if (current_index < states_.size ())
237- return states_[current_index];
238- else
239- PADDLE_THROW (phi::errors::NotFound (" Generator index is not found" ));
216+ phi::Generator::GeneratorState Generator::GetState () {
217+ std::lock_guard<std::mutex> lock (this ->mu_ );
218+ state_.cpu_engine = *engine_;
219+ VLOG (4 ) << " Get Random state: "
220+ << " device id: " << (uint64_t )(this ->state_ .device )
221+ << " , current_seed: " << this ->state_ .current_seed
222+ << " , thread_offset: " << this ->state_ .thread_offset
223+ << " , cpu engine: " << *(this ->engine_ );
224+ return this ->state_ ;
240225}
241226
242- inline std::shared_ptr<std::mt19937_64> Generator::cpu_engine () {
243- return state ().cpu_engine ;
227+ void Generator::SetState (const phi::Generator::GeneratorState& state) {
228+ std::lock_guard<std::mutex> lock (this ->mu_ );
229+ this ->state_ = state;
230+ this ->engine_ = std::make_shared<std::mt19937_64>(state.cpu_engine );
231+ VLOG (4 ) << " Set Random state: "
232+ << " device id: " << (uint64_t )(this ->state_ .device )
233+ << " , current_seed: " << this ->state_ .current_seed
234+ << " , thread_offset: " << this ->state_ .thread_offset
235+ << " , cpu engine: " << *(this ->engine_ );
244236}
245237
246238uint64_t Generator::GetCurrentSeed () {
247- std::lock_guard<std::mutex> lock (mu_);
248- return state (). seed ;
239+ std::lock_guard<std::mutex> lock (this -> mu_ );
240+ return this -> state_ . current_seed ;
249241}
250242
251243uint64_t Generator::Seed () {
252- std::lock_guard<std::mutex> lock (mu_);
253- uint64_t seed = GetRandomSeed ();
254- state ().reset (seed);
255- return seed;
244+ std::lock_guard<std::mutex> lock (this ->mu_ );
245+ uint64_t seed = 0 ;
246+ std::random_device de;
247+ seed = ((((uint64_t )de ()) << 32 ) + de ()) & 0x1FFFFFFFFFFFFF ;
248+ this ->state_ .current_seed = seed;
249+ std::seed_seq seq ({seed});
250+ this ->engine_ ->seed (seq);
251+
252+ return this ->state_ .current_seed ;
256253}
257254
258255void Generator::SetCurrentSeed (uint64_t seed) {
259- std::lock_guard<std::mutex> lock (mu_);
260- state ().reset (seed);
256+ std::lock_guard<std::mutex> lock (this ->mu_ );
257+ this ->state_ .current_seed = seed;
258+ this ->state_ .thread_offset = 0 ;
259+ std::seed_seq seq ({seed});
260+ this ->engine_ ->seed (seq);
261261}
262262
263263std::shared_ptr<std::mt19937_64> Generator::GetCPUEngine () {
264- return cpu_engine ();
264+ std::lock_guard<std::mutex> lock (this ->mu_ );
265+ return this ->engine_ ;
266+ }
267+
268+ void Generator::SetCPUEngine (std::shared_ptr<std::mt19937_64> engine) {
269+ std::lock_guard<std::mutex> lock (this ->mu_ );
270+ this ->engine_ = engine;
265271}
266272
267273uint64_t Generator::Random64 () {
268- std::lock_guard<std::mutex> lock (mu_);
269- auto current_engine = cpu_engine () ;
270- return (*current_engine )();
274+ std::lock_guard<std::mutex> lock (this -> mu_ );
275+ auto engine = this -> engine_ ;
276+ return (*engine )();
271277}
272278
273- std::pair<uint64_t , uint64_t > Generator::IncrementOffset (uint64_t increment) {
279+ std::pair<uint64_t , uint64_t > Generator::IncrementOffset (
280+ uint64_t increment_offset) {
274281#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
275- std::lock_guard<std::mutex> lock (mu_);
276- uint64_t offset = state ().offset ;
277- state ().offset = offset + increment;
278- print_state_info ();
279- return std::make_pair (state ().seed , offset);
282+ std::lock_guard<std::mutex> lock (this ->mu_ );
283+ uint64_t cur_offset = this ->state_ .thread_offset ;
284+ VLOG (10 ) << " cur_offset = " << cur_offset
285+ << " increment_offset = " << increment_offset;
286+ this ->state_ .thread_offset += increment_offset;
287+ return std::make_pair (this ->state_ .current_seed , cur_offset);
280288#else
281289 PADDLE_THROW (phi::errors::PermissionDenied (
282290 " Increment Offset only support in CUDA place" ));
0 commit comments