
Commit 1ffbfc4

Merge pull request #16921 from xjqbest/cherry_pick_16652
Merge pull request #16652 from xjqbest/dataset_merge_develop
2 parents 975aeee + c9a3d3b commit 1ffbfc4

File tree: 12 files changed, +81 −27 lines

paddle/fluid/framework/data_feed.cc

Lines changed: 19 additions & 3 deletions

@@ -242,6 +242,11 @@ void InMemoryDataFeed<T>::SetTrainerNum(int trainer_num) {
   trainer_num_ = trainer_num;
 }
 
+template <typename T>
+void InMemoryDataFeed<T>::SetFleetSendBatchSize(int64_t size) {
+  fleet_send_batch_size_ = size;
+}
+
 template <typename T>
 void InMemoryDataFeed<T>::PutInsToChannel(const std::string& ins_str) {
 #ifdef _LINUX
@@ -361,8 +366,13 @@ void InMemoryDataFeed<T>::GlobalShuffle() {
   VLOG(3) << "GlobalShuffle() begin, thread_id=" << thread_id_;
   auto fleet_ptr = FleetWrapper::GetInstance();
   std::vector<std::vector<T*>> send_vec(trainer_num_);
+  std::vector<int> send_index(trainer_num_);
+  uint64_t reserve_len = fleet_send_batch_size_ / trainer_num_;
   for (auto& vec : send_vec) {
-    vec.reserve(fleet_send_batch_size_);
+    vec.reserve(reserve_len);
+  }
+  for (int i = 0; i < trainer_num_; ++i) {
+    send_index[i] = i;
   }
   std::vector<std::future<int32_t>> total_status;
   auto interval = GetMemoryDataInterval();
@@ -375,7 +385,10 @@ void InMemoryDataFeed<T>::GlobalShuffle() {
     int64_t node_id = random_num % trainer_num_;
     send_vec[node_id].push_back(&((*memory_data_)[i]));
     if (i % fleet_send_batch_size_ == 0 && i != 0) {
-      for (int j = 0; j < send_vec.size(); ++j) {
+      // shuffle the sequence of sending to avoid network timeout error
+      std::random_shuffle(send_index.begin(), send_index.end());
+      for (int index = 0; index < send_index.size(); ++index) {
+        int j = send_index[index];
         std::string send_str;
         SerializeIns(send_vec[j], &send_str);
         VLOG(3) << "send str_length=" << send_str.length()
@@ -388,7 +401,10 @@ void InMemoryDataFeed<T>::GlobalShuffle() {
       }
     }
   }
-  for (int j = 0; j < send_vec.size(); ++j) {
+  // shuffle the sequence of sending to avoid network timeout error
+  std::random_shuffle(send_index.begin(), send_index.end());
+  for (int index = 0; index < send_index.size(); ++index) {
+    int j = send_index[index];
     if (send_vec[j].size() != 0) {
       std::string send_str;
       SerializeIns(send_vec[j], &send_str);
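
The GlobalShuffle() change reserves only fleet_send_batch_size / trainer_num entries per destination buffer and randomizes the order in which those buffers are flushed, so that all workers do not send to the same trainer at the same moment. Below is a minimal Python sketch of that send-order idea, not the actual C++ implementation; send_to_trainer is a hypothetical stand-in for the FleetWrapper RPC.

```python
import random

def global_shuffle_send(memory_data, trainer_num, fleet_send_batch_size,
                        send_to_trainer):
    """Sketch of the randomized send order; send_to_trainer is hypothetical."""
    send_vec = [[] for _ in range(trainer_num)]  # one buffer per destination trainer
    send_index = list(range(trainer_num))        # send order, reshuffled before each flush
    for i, ins in enumerate(memory_data):
        node_id = random.randrange(trainer_num)  # stands in for the hashed random_num
        send_vec[node_id].append(ins)
        if i % fleet_send_batch_size == 0 and i != 0:
            # shuffle the sequence of sending to avoid network timeout error
            random.shuffle(send_index)
            for j in send_index:
                send_to_trainer(j, send_vec[j])
                send_vec[j] = []
    # flush whatever is left, again in a randomized order
    random.shuffle(send_index)
    for j in send_index:
        if send_vec[j]:
            send_to_trainer(j, send_vec[j])
```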

paddle/fluid/framework/data_feed.h

Lines changed: 3 additions & 0 deletions

@@ -94,6 +94,8 @@ class DataFeed {
   virtual void SetThreadNum(int thread_num) {}
   // This function will do nothing at default
   virtual void SetTrainerNum(int trainer_num) {}
+  // This function will do nothing at default
+  virtual void SetFleetSendBatchSize(int64_t size) {}
   virtual void SetFileListMutex(std::mutex* mutex) {
     mutex_for_pick_file_ = mutex;
   }
@@ -212,6 +214,7 @@ class InMemoryDataFeed : public PrivateQueueDataFeed<T> {
   virtual void SetThreadId(int thread_id);
   virtual void SetThreadNum(int thread_num);
   virtual void SetTrainerNum(int trainer_num);
+  virtual void SetFleetSendBatchSize(int64_t size);
   virtual void PutInsToChannel(const std::string& ins_str);
   virtual void FillMemoryDataToChannel();
   virtual void FillChannelToMemoryData();

paddle/fluid/framework/data_set.cc

Lines changed: 11 additions & 0 deletions

@@ -64,6 +64,17 @@ void DatasetImpl<T>::SetTrainerNum(int trainer_num) {
   }
 }
 
+// if you run distributed, and want to do global shuffle,
+// set this before global shuffle.
+// be sure you call CreateReaders before SetFleetSendBatchSize
+template <typename T>
+void DatasetImpl<T>::SetFleetSendBatchSize(int64_t size) {
+  fleet_send_batch_size_ = size;
+  for (auto reader : readers_) {
+    reader->SetFleetSendBatchSize(size);
+  }
+}
+
 template <typename T>
 void DatasetImpl<T>::SetHdfsConfig(const std::string& fs_name,
                                    const std::string& fs_ugi) {
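
The comment on the new setter is the important part: DatasetImpl::SetFleetSendBatchSize stores the value and forwards it only to readers that already exist, so calling it before CreateReaders() changes nothing on the reader side. A rough Python rendering of that propagation, for illustration only (DatasetSketch is not a real class):

```python
class DatasetSketch:
    """Illustrative stand-in for DatasetImpl; not a PaddlePaddle class."""

    def __init__(self):
        self.readers = []                  # filled in by a later create-readers step
        self.fleet_send_batch_size = None

    def set_fleet_send_batch_size(self, size):
        # Store the value and push it into every reader created so far;
        # with an empty reader list only the dataset-level field changes.
        self.fleet_send_batch_size = size
        for reader in self.readers:
            reader.set_fleet_send_batch_size(size)
```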

paddle/fluid/framework/data_set.h

Lines changed: 7 additions & 0 deletions

@@ -47,6 +47,8 @@ class Dataset {
   virtual void SetThreadNum(int thread_num) = 0;
   // set workers' num
   virtual void SetTrainerNum(int trainer_num) = 0;
+  // set fleet send batch size
+  virtual void SetFleetSendBatchSize(int64_t size) = 0;
   // set fs name and ugi
   virtual void SetHdfsConfig(const std::string& fs_name,
                              const std::string& fs_ugi) = 0;
@@ -59,6 +61,8 @@ class Dataset {
   virtual int GetThreadNum() = 0;
   // get worker num
   virtual int GetTrainerNum() = 0;
+  // get fleet send batch size
+  virtual int64_t GetFleetSendBatchSize() = 0;
   // get hdfs config
   virtual std::pair<std::string, std::string> GetHdfsConfig() = 0;
   // get data fedd desc
@@ -98,13 +102,15 @@ class DatasetImpl : public Dataset {
   virtual void SetFileList(const std::vector<std::string>& filelist);
   virtual void SetThreadNum(int thread_num);
   virtual void SetTrainerNum(int trainer_num);
+  virtual void SetFleetSendBatchSize(int64_t size);
   virtual void SetHdfsConfig(const std::string& fs_name,
                              const std::string& fs_ugi);
   virtual void SetDataFeedDesc(const std::string& data_feed_desc_str);
 
   virtual const std::vector<std::string>& GetFileList() { return filelist_; }
   virtual int GetThreadNum() { return thread_num_; }
   virtual int GetTrainerNum() { return trainer_num_; }
+  virtual int64_t GetFleetSendBatchSize() { return fleet_send_batch_size_; }
   virtual std::pair<std::string, std::string> GetHdfsConfig() {
     return std::make_pair(fs_name_, fs_ugi_);
   }
@@ -137,6 +143,7 @@ class DatasetImpl : public Dataset {
   std::string fs_name_;
   std::string fs_ugi_;
   unsigned int rand_seed;
+  int64_t fleet_send_batch_size_;
 };
 
 // use std::vector<MultiSlotType> as data type

paddle/fluid/framework/fleet/fleet_wrapper.cc

Lines changed: 1 addition & 0 deletions

@@ -237,6 +237,7 @@ void FleetWrapper::PushDenseParamSync(
   std::vector<paddle::ps::Region> regions;
   for (auto& t : var_names) {
     Variable* var = scope.FindVar(t);
+    CHECK(var != nullptr) << "var[" << t << "] not found";
     LoDTensor* tensor = var->GetMutable<LoDTensor>();
     float* g = tensor->mutable_data<float>(place);
     paddle::ps::Region reg(g, tensor->numel());

paddle/fluid/framework/io/shell.cc

Lines changed: 1 addition & 1 deletion

@@ -126,7 +126,7 @@ static int shell_popen_fork_internal(const char* real_cmd, bool do_read,
   }
 
   close_open_fds_internal();
-  if (execl("/bin/sh", "sh", "-c", real_cmd, NULL) < 0) {
+  if (execl("/bin/bash", "bash", "-c", real_cmd, NULL) < 0) {
     return -1;
   }
   exit(127);

paddle/fluid/pybind/data_set_py.cc

Lines changed: 4 additions & 0 deletions

@@ -50,11 +50,15 @@ void BindDataset(py::module* m) {
       .def("set_filelist", &framework::Dataset::SetFileList)
       .def("set_thread_num", &framework::Dataset::SetThreadNum)
       .def("set_trainer_num", &framework::Dataset::SetTrainerNum)
+      .def("set_fleet_send_batch_size",
+           &framework::Dataset::SetFleetSendBatchSize)
       .def("set_hdfs_config", &framework::Dataset::SetHdfsConfig)
       .def("set_data_feed_desc", &framework::Dataset::SetDataFeedDesc)
       .def("get_filelist", &framework::Dataset::GetFileList)
       .def("get_thread_num", &framework::Dataset::GetThreadNum)
       .def("get_trainer_num", &framework::Dataset::GetTrainerNum)
+      .def("get_fleet_send_batch_size",
+           &framework::Dataset::GetFleetSendBatchSize)
       .def("get_hdfs_config", &framework::Dataset::GetHdfsConfig)
       .def("get_data_feed_desc", &framework::Dataset::GetDataFeedDesc)
       .def("register_client2client_msg_handler",

python/paddle/fluid/dataset.py

Lines changed: 2 additions & 0 deletions

@@ -236,11 +236,13 @@ def global_shuffle(self, fleet=None):
             fleet: fleet singleton. Default None.
         """
         trainer_num = 1
+        fleet_send_batch_size = 80000
         if fleet is not None:
             fleet.fleet_instance.role_maker_._barrier_worker()
             trainer_num = fleet.worker_num()
         self.dataset.register_client2client_msg_handler()
         self.dataset.set_trainer_num(trainer_num)
+        self.dataset.set_fleet_send_batch_size(fleet_send_batch_size)
         if fleet is not None:
             fleet.fleet_instance.role_maker_._barrier_worker()
         self.dataset.global_shuffle()
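
global_shuffle() now also pins fleet_send_batch_size to 80000 before shuffling, so a distributed run only needs to pass the fleet singleton. A minimal sketch of the call sequence; the fleet import path is assumed from the incubate package modified later in this commit:

```python
import paddle.fluid as fluid
# Assumed import path for the fleet singleton passed to global_shuffle(fleet).
from paddle.fluid.incubate.fleet.parameter_server import fleet

dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
dataset.set_filelist(["train_data/part-000", "train_data/part-001"])
# set_use_var / set_batch_size / set_pipe_command configuration omitted for brevity.
dataset.load_into_memory()
# Redistributes instances across trainers; trainer_num and the fleet send
# batch size are set internally before the shuffle starts.
dataset.global_shuffle(fleet)
```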

python/paddle/fluid/executor.py

Lines changed: 2 additions & 2 deletions

@@ -712,7 +712,7 @@ def infer_from_dataset(self,
         if dataset == None:
             raise RuntimeError("dataset is needed and should be initialized")
 
-        if self.place == paddle.fluid.CUDAPlace():
+        if not isinstance(self.place, core.CPUPlace):
             raise RuntimeError("infer_from_dataset is verified on CPUPlace"
                                "We will open CUDAPlace in the future")
@@ -796,7 +796,7 @@ def train_from_dataset(self,
         if dataset == None:
             raise RuntimeError("dataset is need and should be initialized")
 
-        if self.place == paddle.fluid.CUDAPlace():
+        if not isinstance(self.place, core.CPUPlace):
             raise RuntimeError("train_from_dataset is verified on CPUPlace"
                                "We will open CUDAPlace in the future")
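
The corrected guard checks the executor's place type against core.CPUPlace instead of comparing with a freshly constructed CUDAPlace, so both entry points now reject GPU places cleanly. In practice the dataset entry points are driven from a CPU executor, roughly as follows (full dataset configuration omitted):

```python
import paddle.fluid as fluid

place = fluid.CPUPlace()  # CUDAPlace is rejected by train_from_dataset for now
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())

# `dataset` stands for an InMemoryDataset/QueueDataset prepared as sketched above.
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
exe.train_from_dataset(program=fluid.default_main_program(), dataset=dataset)
```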

python/paddle/fluid/incubate/fleet/parameter_server/__init__.py

Lines changed: 21 additions & 8 deletions

@@ -123,18 +123,25 @@ def init_pserver(self):
             print("You should run DistributedOptimizer.minimize() first")
             sys.exit(-1)
 
-    def init_worker(self, programs):
+    def init_worker(self, programs, scopes=None):
         """
         init_worker(): will be called by user. When a user knows current process is_server(), he/she
             should call init_worker() to initialize global information about worker and connect
-            worker with pserver.
+            worker with pserver. You should run startup program before init_worker.
 
         Args:
             programs(Program|list): a Program or a list of Programs
-
+            scopes(Scope|list): a Scope or a list of Scopes, default None.
         """
         if not isinstance(programs, list):
             programs = [programs]
+        if scopes is None:
+            scopes = [fluid.global_scope()] * len(programs)
+        if len(scopes) != len(programs):
+            print(
+                "You should make sure len(scopes) == len(programs) or set scopes None"
+            )
+            sys.exit(-1)
         if self._opt_info:
             if "fleet_desc" in self._opt_info:
                 self._dist_desc_str = text_format.MessageToString(
@@ -160,7 +167,7 @@ def init_worker(self, programs):
         self.role_maker_._barrier_worker()
         if self.role_maker_._is_first_worker():
             tables = self._dist_desc.trainer_param.dense_table
-            for prog in programs:
+            for prog, scope in zip(programs, scopes):
                 prog_id = str(id(prog))
                 prog_conf = self._opt_info['program_configs'][prog_id]
                 prog_tables = {}
@@ -174,10 +181,16 @@ def init_worker(self, programs):
                         continue
                     var_name_list = []
                     for i in range(0, len(table.dense_variable_name)):
-                        var_name_list.append(table.dense_variable_name[i])
-                    self._fleet_ptr.init_model(prog.desc,
-                                               int(table.table_id),
-                                               var_name_list)
+                        var_name = table.dense_variable_name[i]
+                        if scope.find_var(var_name) is None:
+                            print("var " + var_name +
+                                  " not found in scope, " +
+                                  "you should run startup program first")
+                            sys.exit(-1)
+                        var_name_list.append(var_name)
+                    self._fleet_ptr.init_model(scope,
+                                               int(table.table_id),
+                                               var_name_list)
             # barrier for init model done
             self.role_maker_._barrier_worker()
         else:
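
init_worker() now takes an optional scopes list, defaults it to the global scope for every program, and refuses to proceed if a dense variable is missing from its scope, which is why the startup program has to run first. A hedged usage sketch; the module-level fleet entry point and import path are assumed:

```python
import paddle.fluid as fluid
# Assumed import path and module-level entry point for the fleet singleton.
from paddle.fluid.incubate.fleet.parameter_server import fleet

place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())  # creates the dense vars init_worker checks for

# scopes may be omitted (the global scope is used for every program) or supplied
# one-to-one with the programs list.
fleet.init_worker([fluid.default_main_program()],
                  scopes=[fluid.global_scope()])
```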
