@@ -16,7 +16,6 @@ limitations under the License. */
16
16
17
17
#include < glog/logging.h>
18
18
19
- #include < cstdint>
20
19
#include < memory>
21
20
#include < utility>
22
21
@@ -158,125 +157,134 @@ const std::shared_ptr<Generator>& GetRandomSeedGenerator(
158
157
// RandomGenerator.
159
158
std::shared_ptr<std::mt19937_64> GetCPURandomEngine (uint64_t seed) {
160
159
if (seed == 0 ) {
161
- VLOG (4 ) << " Use random cpu_engine from generator" ;
160
+ VLOG (4 ) << " Use random engine from generator" ;
162
161
return DefaultCPUGenerator ()->GetCPUEngine ();
163
162
} else {
164
- // NOTE(zhiqiu): creating an cpu_engine instance everytime instead of using
163
+ // NOTE(zhiqiu): creating an engine instance everytime instead of using
165
164
// OpDefaultCPUEngine(), this is the legacy behavior of random operators.
166
165
// The benefit is that when runing PE with fixed-seed in multiple thrads,
167
- // each thread has their own cpu_engine , and doesn't affect each other.
166
+ // each thread has their own engine , and doesn't affect each other.
168
167
//
169
168
// And we need to measure the determinacy of Generator in PE.
170
- auto cpu_engine = std::make_shared<std::mt19937_64>();
169
+ auto engine = std::make_shared<std::mt19937_64>();
171
170
static std::mutex mu_;
172
171
{
173
172
std::lock_guard<std::mutex> lock (mu_);
174
- cpu_engine ->seed (seed);
173
+ engine ->seed (seed);
175
174
}
176
- return cpu_engine ;
175
+ return engine ;
177
176
}
178
177
}
179
178
180
- inline void Generator::print_state_info () {
181
- VLOG (4 ) << " Generator Random state "
182
- << " device id: " << state ().device << " , seed: " << state ().seed
183
- << " , offset: " << state ().offset << " , cpu_engine: " << cpu_engine ();
184
- }
185
-
186
179
// Default constructor: seeds the generator with a value from GetRandomSeed().
//
// The device field is set to -1 and the thread offset starts at 0. Both the
// shared engine (engine_) and the by-value copy kept in state_.cpu_engine
// start from the same engine state.
//
// NOTE: std::seed_seq reduces every input value modulo 2^32, so only the low
// 32 bits of the seed influence the CPU engine state.
Generator::Generator() {
  auto initial_seed = GetRandomSeed();
  std::seed_seq seed_sequence({initial_seed});

  engine_ = std::make_shared<std::mt19937_64>(seed_sequence);

  state_.device = -1;
  state_.current_seed = initial_seed;
  state_.thread_offset = 0;
  state_.cpu_engine = *engine_;

  VLOG(4) << "initial seed: " << state_.current_seed
          << ", cpu engine: " << &state_.cpu_engine;
}
192
191
193
192
// Constructs a generator from an explicit seed.
//
// The device field is set to -1 and the thread offset starts at 0. engine_
// and the by-value copy in state_.cpu_engine start from the same engine state.
//
// NOTE: std::seed_seq reduces every input value modulo 2^32, so only the low
// 32 bits of `seed` influence the CPU engine state; the full 64-bit value is
// still recorded in state_.current_seed.
Generator::Generator(uint64_t seed) {
  std::seed_seq seed_sequence({seed});

  engine_ = std::make_shared<std::mt19937_64>(seed_sequence);

  state_.device = -1;
  state_.current_seed = seed;
  state_.thread_offset = 0;
  state_.cpu_engine = *engine_;

  VLOG(4) << "initial seed: " << state_.current_seed
          << ", cpu engine: " << &state_.cpu_engine;
}
226
203
227
- uint64_t Generator::RegisterStateIndex (const GeneratorState& state) {
228
- std::lock_guard<std::mutex> lock (mu_);
229
- auto new_index = states_.size ();
230
- states_.push_back (state);
231
- current_index = new_index;
232
- return new_index;
204
+ Generator::Generator (uint64_t seed, uint64_t device_id) {
205
+ std::seed_seq seq ({seed});
206
+ auto engine = std::make_shared<std::mt19937_64>(seq);
207
+ this ->state_ .cpu_engine = *engine;
208
+ this ->state_ .device = static_cast <int64_t >(device_id);
209
+ this ->state_ .current_seed = seed;
210
+ this ->state_ .thread_offset = 0 ;
211
+ this ->engine_ = engine;
212
+ VLOG (4 ) << " initial seed: " << this ->state_ .current_seed
213
+ << " , cpu engine: " << &this ->state_ .cpu_engine ;
233
214
}
234
215
235
- inline Generator::GeneratorState& Generator::state () {
236
- if (current_index < states_.size ())
237
- return states_[current_index];
238
- else
239
- PADDLE_THROW (phi::errors::NotFound (" Generator index is not found" ));
216
+ phi::Generator::GeneratorState Generator::GetState () {
217
+ std::lock_guard<std::mutex> lock (this ->mu_ );
218
+ state_.cpu_engine = *engine_;
219
+ VLOG (4 ) << " Get Random state: "
220
+ << " device id: " << (uint64_t )(this ->state_ .device )
221
+ << " , current_seed: " << this ->state_ .current_seed
222
+ << " , thread_offset: " << this ->state_ .thread_offset
223
+ << " , cpu engine: " << *(this ->engine_ );
224
+ return this ->state_ ;
240
225
}
241
226
242
- inline std::shared_ptr<std::mt19937_64> Generator::cpu_engine () {
243
- return state ().cpu_engine ;
227
+ void Generator::SetState (const phi::Generator::GeneratorState& state) {
228
+ std::lock_guard<std::mutex> lock (this ->mu_ );
229
+ this ->state_ = state;
230
+ this ->engine_ = std::make_shared<std::mt19937_64>(state.cpu_engine );
231
+ VLOG (4 ) << " Set Random state: "
232
+ << " device id: " << (uint64_t )(this ->state_ .device )
233
+ << " , current_seed: " << this ->state_ .current_seed
234
+ << " , thread_offset: " << this ->state_ .thread_offset
235
+ << " , cpu engine: " << *(this ->engine_ );
244
236
}
245
237
246
238
uint64_t Generator::GetCurrentSeed () {
247
- std::lock_guard<std::mutex> lock (mu_);
248
- return state (). seed ;
239
+ std::lock_guard<std::mutex> lock (this -> mu_ );
240
+ return this -> state_ . current_seed ;
249
241
}
250
242
251
243
uint64_t Generator::Seed () {
252
- std::lock_guard<std::mutex> lock (mu_);
253
- uint64_t seed = GetRandomSeed ();
254
- state ().reset (seed);
255
- return seed;
244
+ std::lock_guard<std::mutex> lock (this ->mu_ );
245
+ uint64_t seed = 0 ;
246
+ std::random_device de;
247
+ seed = ((((uint64_t )de ()) << 32 ) + de ()) & 0x1FFFFFFFFFFFFF ;
248
+ this ->state_ .current_seed = seed;
249
+ std::seed_seq seq ({seed});
250
+ this ->engine_ ->seed (seq);
251
+
252
+ return this ->state_ .current_seed ;
256
253
}
257
254
258
255
void Generator::SetCurrentSeed (uint64_t seed) {
259
- std::lock_guard<std::mutex> lock (mu_);
260
- state ().reset (seed);
256
+ std::lock_guard<std::mutex> lock (this ->mu_ );
257
+ this ->state_ .current_seed = seed;
258
+ this ->state_ .thread_offset = 0 ;
259
+ std::seed_seq seq ({seed});
260
+ this ->engine_ ->seed (seq);
261
261
}
262
262
263
263
std::shared_ptr<std::mt19937_64> Generator::GetCPUEngine () {
264
- return cpu_engine ();
264
+ std::lock_guard<std::mutex> lock (this ->mu_ );
265
+ return this ->engine_ ;
266
+ }
267
+
268
+ void Generator::SetCPUEngine (std::shared_ptr<std::mt19937_64> engine) {
269
+ std::lock_guard<std::mutex> lock (this ->mu_ );
270
+ this ->engine_ = engine;
265
271
}
266
272
267
273
uint64_t Generator::Random64 () {
268
- std::lock_guard<std::mutex> lock (mu_);
269
- auto current_engine = cpu_engine () ;
270
- return (*current_engine )();
274
+ std::lock_guard<std::mutex> lock (this -> mu_ );
275
+ auto engine = this -> engine_ ;
276
+ return (*engine )();
271
277
}
272
278
273
- std::pair<uint64_t , uint64_t > Generator::IncrementOffset (uint64_t increment) {
279
+ std::pair<uint64_t , uint64_t > Generator::IncrementOffset (
280
+ uint64_t increment_offset) {
274
281
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
275
- std::lock_guard<std::mutex> lock (mu_);
276
- uint64_t offset = state ().offset ;
277
- state ().offset = offset + increment;
278
- print_state_info ();
279
- return std::make_pair (state ().seed , offset);
282
+ std::lock_guard<std::mutex> lock (this ->mu_ );
283
+ uint64_t cur_offset = this ->state_ .thread_offset ;
284
+ VLOG (10 ) << " cur_offset = " << cur_offset
285
+ << " increment_offset = " << increment_offset;
286
+ this ->state_ .thread_offset += increment_offset;
287
+ return std::make_pair (this ->state_ .current_seed , cur_offset);
280
288
#else
281
289
PADDLE_THROW (phi::errors::PermissionDenied (
282
290
" Increment Offset only support in CUDA place" ));
0 commit comments