Skip to content

Commit e032331

Browse files
authored
Revert "Enhanced RNG State Management with Index-Based Control for Graph-Safe Tensor Parallelism (#58859)" (#60148)
This reverts commit 3bcdeef.
1 parent 1115890 commit e032331

File tree

15 files changed

+158
-583
lines changed

15 files changed

+158
-583
lines changed

paddle/fluid/pybind/generator_py.cc

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -38,35 +38,37 @@ void BindGenerator(py::module* m_ptr) {
3838
"GeneratorState")
3939
.def("current_seed",
4040
[](std::shared_ptr<phi::Generator::GeneratorState>& self) {
41-
return self->seed;
41+
return self->current_seed;
4242
})
4343
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \
4444
defined(PADDLE_WITH_CUSTOM_DEVICE) || defined(PADDLE_WITH_XPU)
4545
// NOTE(shenliang03): Due to the inability to serialize mt19937_64
4646
// type, resulting in a problem with precision under the cpu.
4747
.def(py::pickle(
4848
[](const phi::Generator::GeneratorState& s) { // __getstate__
49-
return py::make_tuple(s.device, s.seed, s.offset);
49+
return py::make_tuple(s.device, s.current_seed, s.thread_offset);
5050
},
5151
[](py::tuple s) { // __setstate__
5252
if (s.size() != 3)
5353
throw std::runtime_error(
5454
"Invalid Random state. Please check the format(device, "
5555
"current_seed, thread_offset).");
5656

57-
int64_t device = s[0].cast<int64_t>();
58-
int64_t seed = s[1].cast<int64_t>();
59-
uint64_t offset = s[2].cast<uint64_t>();
60-
61-
phi::Generator::GeneratorState state(device, seed, offset);
57+
phi::Generator::GeneratorState state;
58+
state.device = s[0].cast<std::int64_t>();
59+
state.current_seed = s[1].cast<std::uint64_t>();
60+
state.thread_offset = s[2].cast<std::uint64_t>();
6261

62+
std::seed_seq seq({state.current_seed});
63+
auto engine = std::make_shared<std::mt19937_64>(seq);
64+
state.cpu_engine = *engine;
6365
return state;
6466
}))
6567
#endif
6668
.def("__str__", [](const phi::Generator::GeneratorState& self) {
6769
std::stringstream ostr;
68-
ostr << self.device << " " << self.seed << " " << self.offset << " "
69-
<< self.cpu_engine;
70+
ostr << self.device << " " << self.current_seed << " "
71+
<< self.thread_offset << " " << self.cpu_engine;
7072
return ostr.str();
7173
});
7274

@@ -76,9 +78,6 @@ void BindGenerator(py::module* m_ptr) {
7678
[](phi::Generator& self) { new (&self) phi::Generator(); })
7779
.def("get_state", &phi::Generator::GetState)
7880
.def("set_state", &phi::Generator::SetState)
79-
.def("get_state_index", &phi::Generator::GetStateIndex)
80-
.def("set_state_index", &phi::Generator::SetStateIndex)
81-
.def("register_state_index", &phi::Generator::RegisterStateIndex)
8281
.def("manual_seed",
8382
[](std::shared_ptr<phi::Generator>& self, uint64_t seed) {
8483
self->SetCurrentSeed(seed);

paddle/phi/core/generator.cc

Lines changed: 86 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ limitations under the License. */
1616

1717
#include <glog/logging.h>
1818

19-
#include <cstdint>
2019
#include <memory>
2120
#include <utility>
2221

@@ -158,125 +157,134 @@ const std::shared_ptr<Generator>& GetRandomSeedGenerator(
158157
// RandomGenerator.
159158
std::shared_ptr<std::mt19937_64> GetCPURandomEngine(uint64_t seed) {
160159
if (seed == 0) {
161-
VLOG(4) << "Use random cpu_engine from generator";
160+
VLOG(4) << "Use random engine from generator";
162161
return DefaultCPUGenerator()->GetCPUEngine();
163162
} else {
164-
// NOTE(zhiqiu): creating an cpu_engine instance everytime instead of using
163+
// NOTE(zhiqiu): creating an engine instance everytime instead of using
165164
// OpDefaultCPUEngine(), this is the legacy behavior of random operators.
166165
// The benefit is that when runing PE with fixed-seed in multiple thrads,
167-
// each thread has their own cpu_engine, and doesn't affect each other.
166+
// each thread has their own engine, and doesn't affect each other.
168167
//
169168
// And we need to measure the determinacy of Generator in PE.
170-
auto cpu_engine = std::make_shared<std::mt19937_64>();
169+
auto engine = std::make_shared<std::mt19937_64>();
171170
static std::mutex mu_;
172171
{
173172
std::lock_guard<std::mutex> lock(mu_);
174-
cpu_engine->seed(seed);
173+
engine->seed(seed);
175174
}
176-
return cpu_engine;
175+
return engine;
177176
}
178177
}
179178

180-
inline void Generator::print_state_info() {
181-
VLOG(4) << "Generator Random state "
182-
<< "device id: " << state().device << ", seed: " << state().seed
183-
<< ", offset: " << state().offset << ", cpu_engine: " << cpu_engine();
184-
}
185-
186179
Generator::Generator() {
187180
auto seed = GetRandomSeed();
188-
current_index = states_.size();
189-
states_.emplace_back(seed);
190-
print_state_info();
181+
std::seed_seq seq({seed});
182+
auto engine = std::make_shared<std::mt19937_64>(seq);
183+
this->state_.cpu_engine = *engine;
184+
this->state_.device = -1;
185+
this->state_.current_seed = seed;
186+
this->state_.thread_offset = 0;
187+
this->engine_ = engine;
188+
VLOG(4) << "initial seed: " << this->state_.current_seed
189+
<< ", cpu engine: " << &this->state_.cpu_engine;
191190
}
192191

193192
Generator::Generator(uint64_t seed) {
194-
current_index = states_.size();
195-
states_.emplace_back(-1, seed);
196-
print_state_info();
197-
}
198-
199-
Generator::Generator(uint64_t seed, int64_t device_id) {
200-
current_index = states_.size();
201-
// device id first, then seed
202-
states_.emplace_back(device_id, seed);
203-
print_state_info();
204-
}
205-
206-
phi::Generator::GeneratorState Generator::GetState() { return state().clone(); }
207-
208-
void Generator::SetState(const phi::Generator::GeneratorState& state) {
209-
std::lock_guard<std::mutex> lock(mu_);
210-
if (current_index < states_.size())
211-
states_[current_index] = state.clone();
212-
else
213-
PADDLE_THROW(phi::errors::NotFound("Generator index is not found"));
214-
print_state_info();
215-
}
216-
217-
uint64_t Generator::GetStateIndex() { return current_index; }
218-
219-
void Generator::SetStateIndex(uint64_t StateIndex) {
220-
std::lock_guard<std::mutex> lock(mu_);
221-
if (current_index < states_.size())
222-
current_index = StateIndex;
223-
else
224-
PADDLE_THROW(phi::errors::NotFound("Generator index is not found"));
193+
std::seed_seq seq({seed});
194+
auto engine = std::make_shared<std::mt19937_64>(seq);
195+
this->state_.cpu_engine = *engine;
196+
this->state_.device = -1;
197+
this->state_.current_seed = seed;
198+
this->state_.thread_offset = 0;
199+
this->engine_ = engine;
200+
VLOG(4) << "initial seed: " << this->state_.current_seed
201+
<< ", cpu engine: " << &this->state_.cpu_engine;
225202
}
226203

227-
uint64_t Generator::RegisterStateIndex(const GeneratorState& state) {
228-
std::lock_guard<std::mutex> lock(mu_);
229-
auto new_index = states_.size();
230-
states_.push_back(state);
231-
current_index = new_index;
232-
return new_index;
204+
Generator::Generator(uint64_t seed, uint64_t device_id) {
205+
std::seed_seq seq({seed});
206+
auto engine = std::make_shared<std::mt19937_64>(seq);
207+
this->state_.cpu_engine = *engine;
208+
this->state_.device = static_cast<int64_t>(device_id);
209+
this->state_.current_seed = seed;
210+
this->state_.thread_offset = 0;
211+
this->engine_ = engine;
212+
VLOG(4) << "initial seed: " << this->state_.current_seed
213+
<< ", cpu engine: " << &this->state_.cpu_engine;
233214
}
234215

235-
inline Generator::GeneratorState& Generator::state() {
236-
if (current_index < states_.size())
237-
return states_[current_index];
238-
else
239-
PADDLE_THROW(phi::errors::NotFound("Generator index is not found"));
216+
phi::Generator::GeneratorState Generator::GetState() {
217+
std::lock_guard<std::mutex> lock(this->mu_);
218+
state_.cpu_engine = *engine_;
219+
VLOG(4) << "Get Random state: "
220+
<< "device id: " << (uint64_t)(this->state_.device)
221+
<< ", current_seed: " << this->state_.current_seed
222+
<< ", thread_offset: " << this->state_.thread_offset
223+
<< ", cpu engine: " << *(this->engine_);
224+
return this->state_;
240225
}
241226

242-
inline std::shared_ptr<std::mt19937_64> Generator::cpu_engine() {
243-
return state().cpu_engine;
227+
void Generator::SetState(const phi::Generator::GeneratorState& state) {
228+
std::lock_guard<std::mutex> lock(this->mu_);
229+
this->state_ = state;
230+
this->engine_ = std::make_shared<std::mt19937_64>(state.cpu_engine);
231+
VLOG(4) << "Set Random state: "
232+
<< "device id: " << (uint64_t)(this->state_.device)
233+
<< ", current_seed: " << this->state_.current_seed
234+
<< ", thread_offset: " << this->state_.thread_offset
235+
<< ", cpu engine: " << *(this->engine_);
244236
}
245237

246238
uint64_t Generator::GetCurrentSeed() {
247-
std::lock_guard<std::mutex> lock(mu_);
248-
return state().seed;
239+
std::lock_guard<std::mutex> lock(this->mu_);
240+
return this->state_.current_seed;
249241
}
250242

251243
uint64_t Generator::Seed() {
252-
std::lock_guard<std::mutex> lock(mu_);
253-
uint64_t seed = GetRandomSeed();
254-
state().reset(seed);
255-
return seed;
244+
std::lock_guard<std::mutex> lock(this->mu_);
245+
uint64_t seed = 0;
246+
std::random_device de;
247+
seed = ((((uint64_t)de()) << 32) + de()) & 0x1FFFFFFFFFFFFF;
248+
this->state_.current_seed = seed;
249+
std::seed_seq seq({seed});
250+
this->engine_->seed(seq);
251+
252+
return this->state_.current_seed;
256253
}
257254

258255
void Generator::SetCurrentSeed(uint64_t seed) {
259-
std::lock_guard<std::mutex> lock(mu_);
260-
state().reset(seed);
256+
std::lock_guard<std::mutex> lock(this->mu_);
257+
this->state_.current_seed = seed;
258+
this->state_.thread_offset = 0;
259+
std::seed_seq seq({seed});
260+
this->engine_->seed(seq);
261261
}
262262

263263
std::shared_ptr<std::mt19937_64> Generator::GetCPUEngine() {
264-
return cpu_engine();
264+
std::lock_guard<std::mutex> lock(this->mu_);
265+
return this->engine_;
266+
}
267+
268+
void Generator::SetCPUEngine(std::shared_ptr<std::mt19937_64> engine) {
269+
std::lock_guard<std::mutex> lock(this->mu_);
270+
this->engine_ = engine;
265271
}
266272

267273
uint64_t Generator::Random64() {
268-
std::lock_guard<std::mutex> lock(mu_);
269-
auto current_engine = cpu_engine();
270-
return (*current_engine)();
274+
std::lock_guard<std::mutex> lock(this->mu_);
275+
auto engine = this->engine_;
276+
return (*engine)();
271277
}
272278

273-
std::pair<uint64_t, uint64_t> Generator::IncrementOffset(uint64_t increment) {
279+
std::pair<uint64_t, uint64_t> Generator::IncrementOffset(
280+
uint64_t increment_offset) {
274281
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
275-
std::lock_guard<std::mutex> lock(mu_);
276-
uint64_t offset = state().offset;
277-
state().offset = offset + increment;
278-
print_state_info();
279-
return std::make_pair(state().seed, offset);
282+
std::lock_guard<std::mutex> lock(this->mu_);
283+
uint64_t cur_offset = this->state_.thread_offset;
284+
VLOG(10) << "cur_offset = " << cur_offset
285+
<< " increment_offset = " << increment_offset;
286+
this->state_.thread_offset += increment_offset;
287+
return std::make_pair(this->state_.current_seed, cur_offset);
280288
#else
281289
PADDLE_THROW(phi::errors::PermissionDenied(
282290
"Increment Offset only support in CUDA place"));

0 commit comments

Comments
 (0)