Commit 9515ebb

wangguibao authored and committed
AsyncExecutor (#14627)
* AsyncExecutor: C++ side
* Google naming conventions
* Rename MultiExecutor to AsyncExecutor
* pybind with async_executor
* Naming convention
* remove some flags and unused code
* add refactored file of async_executor and data_feed
* clear async executor interface and add data feed factory
* split async executor into executor_thread_worker and async_executor, refactor pybind, add datafeed and corresponding proto
* Fix async_executor interfaces: 1) Remove all protobufs; 2) Stop after each epoch
* refine async_executor_refactor.cc
* add some files about datafeed
* Revert "add some files about datafeed" (this reverts commit 8ee8133)
* Interface rework
* add MultiSlotDataFeed
* Create DataFeedDesc from .proto file, then manipulate it (add/del fields etc.) from the Python side
* update data_feed to add MultiSlotDataFeed
* update datafeed and async_executor to run the bow_net demo
* fix bug where finish_set_filelist failed in multithread
* delete finish_binding_memory_ flag, because it cannot be marked under the current interface
* Fix bug
* update async_executor.py to support set_use_slots
* update async_executor.py to support set_use_slots and set_dense_slots
* fix bug where fetch returns NaN when the number of files is less than the number of threads
* remove redundant code, and make the executor exit when an illegal queue size is set
* add batch_size check
* add MultiSlotDesc
* Revert "add MultiSlotDesc" (this reverts commit 2e72ebf)
* add some checkpoints in DataFeedDesc
* add CheckFile function in MultiSlotDataFeed
* update some error info
* fix deadlock bug
* Fix fetch variable
* Merge error
* fix code style in async_executor
* use a one-lock blocking queue instead of a two-lock blocking queue, because of some bugs
* update code style
* add utest for data_feed
* Fix fetch var
* update utest of data_feed for multithread
* update SetFileList info
* fix bug in utest of data_feed
* Add comments for Python
* Add comments for Python code
* Fix pybind.cc with new pybind11 version
* add note for DataFeedDesc's set_use_slots function
* Add save_model
* update data_feed_test for multi-type
* add comment for executor_thread_worker
* Remove unused code
* update data_feed_test to generate the test data file
* remove unnecessary interfaces and add comments
* C++ style check
* update data_feed.cc
* commit for code style
* commit for code style
* commit for code style
* commit for code style
* Comment away __init__ in async_executor.py
* clang-format fix test=develop
* use PADDLE_THROW instead of exit(-1); use unique_ptr to manage scope vars in data_feed_test.cc
* commit for update code style
* commit for update code style
* Add async_executor demo; Remove some methods test=develop
* commit for update code style
* commit for update code style
* commit for update code style
* update API.spec
* AsyncExecutor test=develop
* AsyncExecutor test=develop
* AsyncExecutor test=develop
* AsyncExecutor test=develop
* Fix API.spec test=develop
* Fix API.spec test=develop
* Fix Windows build error test=develop
* Fix Windows build error test=develop
* Fix Windows build error test=develop
* Fix Windows build error test=develop
* Fix Windows build test=develop
* Fix Windows build test=develop
* Fix Windows build test=develop
* Fix code style test=develop
* Fix code style test=develop
* update datafeed
* Fix code style test=develop
* update data_feed_test to test Tensor test=develop
* Fix code style test=develop
* Fix Windows build failure test=develop
* Fix code style and Windows build failure test=develop
* Fix Python 3.5 build failure test=develop
* AsyncExecutor API test=develop
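One of the commits above replaces a two-lock blocking queue with a one-lock blocking queue to avoid bugs. The idea can be sketched in Python (an illustrative model only, not Paddle's C++ implementation; the class name is hypothetical):

```python
import threading
from collections import deque

class OneLockBlockingQueue:
    """A bounded blocking queue guarded by a single mutex.

    Guarding both ends with one lock avoids the subtle races and
    deadlocks a two-lock (separate head/tail) design can run into,
    at the cost of some producer/consumer contention.
    """

    def __init__(self, capacity):
        self._items = deque()
        self._capacity = capacity
        self._mutex = threading.Lock()
        # Both conditions share the single mutex.
        self._not_full = threading.Condition(self._mutex)
        self._not_empty = threading.Condition(self._mutex)

    def put(self, item):
        with self._not_full:
            while len(self._items) >= self._capacity:
                self._not_full.wait()
            self._items.append(item)
            self._not_empty.notify()

    def get(self):
        with self._not_empty:
            while not self._items:
                self._not_empty.wait()
            item = self._items.popleft()
            self._not_full.notify()
            return item
```

A producer blocked on a full queue is woken by the consumer's `notify`, so a capacity-1 queue still streams all items through in FIFO order.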
1 parent 992e38a commit 9515ebb

29 files changed (+2356 −74 lines)

paddle/fluid/API.spec

Lines changed: 7 additions & 0 deletions
@@ -32,6 +32,13 @@ paddle.fluid.BuildStrategy.ReduceStrategy.__init__ __init__(self: paddle.fluid.c
 paddle.fluid.BuildStrategy.__init__ __init__(self: paddle.fluid.core.ParallelExecutor.BuildStrategy) -> None
 paddle.fluid.create_lod_tensor ArgSpec(args=['data', 'recursive_seq_lens', 'place'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.create_random_int_lodtensor ArgSpec(args=['recursive_seq_lens', 'base_shape', 'place', 'low', 'high'], varargs=None, keywords=None, defaults=None)
+paddle.fluid.DataFeedDesc.__init__ ArgSpec(args=['self', 'proto_file'], varargs=None, keywords=None, defaults=None)
+paddle.fluid.DataFeedDesc.desc ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
+paddle.fluid.DataFeedDesc.set_batch_size ArgSpec(args=['self', 'batch_size'], varargs=None, keywords=None, defaults=None)
+paddle.fluid.DataFeedDesc.set_dense_slots ArgSpec(args=['self', 'dense_slots_name'], varargs=None, keywords=None, defaults=None)
+paddle.fluid.DataFeedDesc.set_use_slots ArgSpec(args=['self', 'use_slots_name'], varargs=None, keywords=None, defaults=None)
+paddle.fluid.AsyncExecutor.__init__ ArgSpec(args=['self', 'place'], varargs=None, keywords=None, defaults=(None,))
+paddle.fluid.AsyncExecutor.run ArgSpec(args=['self', 'program', 'data_feed', 'filelist', 'thread_num', 'fetch', 'debug'], varargs=None, keywords=None, defaults=(False,))
 paddle.fluid.io.save_vars ArgSpec(args=['executor', 'dirname', 'main_program', 'vars', 'predicate', 'filename'], varargs=None, keywords=None, defaults=(None, None, None, None))
 paddle.fluid.io.save_params ArgSpec(args=['executor', 'dirname', 'main_program', 'filename'], varargs=None, keywords=None, defaults=(None, None))
 paddle.fluid.io.save_persistables ArgSpec(args=['executor', 'dirname', 'main_program', 'filename'], varargs=None, keywords=None, defaults=(None, None))
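Per the commit message, a `DataFeedDesc` is created from a `.proto`-style text file and then manipulated (`set_batch_size`, `set_use_slots`, ...) from the Python side. The idea can be sketched standalone with plain string handling in place of real protobuf (class and field handling here are illustrative, not paddle.fluid internals):

```python
import re

class FeedDescSketch:
    """Toy stand-in for a text-format data feed description.

    Holds the description as text and rewrites a field in place,
    mimicking what DataFeedDesc.set_batch_size does on the real
    protobuf message.
    """

    def __init__(self, proto_text):
        self._text = proto_text

    def set_batch_size(self, batch_size):
        # Rewrite the batch_size field inside the text description.
        self._text = re.sub(r"batch_size:\s*\d+",
                            "batch_size: %d" % batch_size, self._text)

    def desc(self):
        # Like DataFeedDesc.desc(): return the current text form.
        return self._text

desc = FeedDescSketch('name: "MultiSlotDataFeed"\nbatch_size: 32\n')
desc.set_batch_size(128)
```

Keeping the description in text form is what lets the Python side tweak it before handing the string to the C++ side, which parses it with `google::protobuf::TextFormat`.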

paddle/fluid/framework/CMakeLists.txt

Lines changed: 11 additions & 6 deletions
@@ -34,6 +34,7 @@ add_subdirectory(ir)
 add_subdirectory(details)
 # ddim lib
 proto_library(framework_proto SRCS framework.proto)
+proto_library(async_executor_proto SRCS data_feed.proto)

 cc_library(ddim SRCS ddim.cc DEPS eigen3 boost)
 cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
@@ -135,7 +136,7 @@ endif(NOT WIN32)
 cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc)
 nv_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)

-py_proto_compile(framework_py_proto SRCS framework.proto)
+py_proto_compile(framework_py_proto SRCS framework.proto data_feed.proto)
 # Generate an empty __init__.py to make framework_py_proto as a valid python module.
 add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
 add_dependencies(framework_py_proto framework_py_proto_init)
@@ -157,27 +158,31 @@ endif(NOT WIN32)
 cc_library(lod_rank_table SRCS lod_rank_table.cc DEPS lod_tensor)

 cc_library(feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glog)
+cc_library(variable_helper SRCS variable_helper.cc DEPS lod_tensor)

-cc_library(naive_executor SRCS naive_executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass)
+cc_library(naive_executor SRCS naive_executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass variable_helper)

 if(WITH_DISTRIBUTE)
-cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method sendrecvop_grpc cares grpc++_unsecure grpc_unsecure gpr graph_to_program_pass)
+cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method sendrecvop_grpc cares grpc++_unsecure grpc_unsecure gpr graph_to_program_pass variable_helper)
 set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
 set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
 else()
 if(NOT WIN32)
-cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass ngraph_operator)
+cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass ngraph_operator variable_helper)
 else(NOT WIN32)
-cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass)
+cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass variable_helper)
 endif(NOT WIN32)
 cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op)
 endif()

 cc_library(parallel_executor SRCS parallel_executor.cc DEPS
 threaded_ssa_graph_executor scope_buffered_ssa_graph_executor
 graph build_strategy
-fast_threaded_ssa_graph_executor)
+fast_threaded_ssa_graph_executor variable_helper)

+cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc executor_thread_worker.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass async_executor_proto variable_helper)
+
+cc_test(data_feed_test SRCS data_feed_test.cc DEPS async_executor)
 cc_library(prune SRCS prune.cc DEPS framework_proto)
 cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
 cc_test(var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry
paddle/fluid/framework/async_executor.cc

Lines changed: 138 additions & 0 deletions
@@ -0,0 +1,138 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/async_executor.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"

#include "gflags/gflags.h"
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/executor_thread_worker.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/pybind/pybind.h"

namespace paddle {
namespace framework {
AsyncExecutor::AsyncExecutor(Scope* scope, const platform::Place& place)
    : root_scope_(scope), place_(place) {}

void AsyncExecutor::CreateThreads(
    ExecutorThreadWorker* worker, const ProgramDesc& main_program,
    const std::shared_ptr<DataFeed>& reader,
    const std::vector<std::string>& fetch_var_names, Scope* root_scope,
    const int thread_index, const bool debug) {
  worker->SetThreadId(thread_index);
  worker->SetDebug(debug);
  worker->SetRootScope(root_scope);
  worker->CreateThreadResource(main_program, place_);
  worker->SetDataFeed(reader);
  worker->SetFetchVarNames(fetch_var_names);
  worker->BindingDataFeedMemory();
}

void PrepareReaders(std::vector<std::shared_ptr<DataFeed>>& readers,  // NOLINT
                    const int thread_num, const DataFeedDesc& data_feed_desc,
                    const std::vector<std::string>& filelist) {
  readers.resize(thread_num);
  for (size_t i = 0; i < readers.size(); ++i) {
    readers[i] = DataFeedFactory::CreateDataFeed(data_feed_desc.name());
    readers[i]->Init(data_feed_desc);  // set batch_size and queue_size here
  }
  readers[0]->SetFileList(filelist);
}

void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
                                const std::string& data_feed_desc_str,
                                const std::vector<std::string>& filelist,
                                const int thread_num,
                                const std::vector<std::string>& fetch_var_names,
                                const bool debug) {
  std::vector<std::thread> threads;

  auto& block = main_program.Block(0);
  for (auto var_name : fetch_var_names) {
    auto var_desc = block.FindVar(var_name);
    auto shapes = var_desc->GetShape();
    PADDLE_ENFORCE(shapes[shapes.size() - 1] == 1,
                   "var %s: Fetched var has wrong shape, "
                   "only variables with the last dimension size 1 supported",
                   var_name);
  }

  DataFeedDesc data_feed_desc;
  google::protobuf::TextFormat::ParseFromString(data_feed_desc_str,
                                                &data_feed_desc);

  int actual_thread_num = thread_num;
  int file_cnt = filelist.size();
  PADDLE_ENFORCE(file_cnt > 0, "File list cannot be empty");

  if (actual_thread_num > file_cnt) {
    VLOG(1) << "Thread num = " << thread_num << ", file num = " << file_cnt
            << ". Changing thread_num = " << file_cnt;
    actual_thread_num = file_cnt;
  }

  /*
    readerDesc: protobuf description for reader initialization
    argument: class_name, batch_size, use_slot, queue_size, buffer_size,
    padding_index

    reader:
    1) each thread has a reader; the reader reads input data and
       puts it into the input queue
    2) each reader has a Next() interface that can fetch an instance
       from the input queue
  */
  // TODO: should be a factory method for creating the datafeed
  std::vector<std::shared_ptr<DataFeed>> readers;
  PrepareReaders(readers, actual_thread_num, data_feed_desc, filelist);

  std::vector<std::shared_ptr<ExecutorThreadWorker>> workers;
  workers.resize(actual_thread_num);
  for (auto& worker : workers) {
    worker.reset(new ExecutorThreadWorker);
  }

  // prepare thread resource here
  for (int thidx = 0; thidx < actual_thread_num; ++thidx) {
    CreateThreads(workers[thidx].get(), main_program, readers[thidx],
                  fetch_var_names, root_scope_, thidx, debug);
  }

  // start executing ops in multiple threads
  for (int thidx = 0; thidx < actual_thread_num; ++thidx) {
    threads.push_back(
        std::thread(&ExecutorThreadWorker::TrainFiles, workers[thidx].get()));
  }

  for (auto& th : threads) {
    th.join();
  }

  root_scope_->DropKids();

  return;
}

}  // end namespace framework
}  // end namespace paddle
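`RunFromFile` above clamps the thread count to the number of input files, builds one worker per thread, runs `TrainFiles` on each worker's own thread, and joins them all. That control flow can be modeled in Python (a simplified sketch in which workers pull files from one shared queue; function and parameter names are illustrative, not Paddle's API):

```python
import threading
from queue import Queue, Empty

def run_from_filelist(filelist, thread_num, train_one_file):
    """Mimic AsyncExecutor::RunFromFile's orchestration: clamp
    thread_num to the file count, let workers drain a shared file
    queue, and join every worker before returning."""
    if not filelist:
        raise ValueError("File list cannot be empty")
    # Mirrors: if (actual_thread_num > file_cnt) actual_thread_num = file_cnt;
    actual_thread_num = min(thread_num, len(filelist))

    files = Queue()
    for fname in filelist:
        files.put(fname)

    def worker():
        # Each worker trains on files until the shared queue is empty.
        while True:
            try:
                fname = files.get_nowait()
            except Empty:
                return
            train_one_file(fname)

    threads = [threading.Thread(target=worker)
               for _ in range(actual_thread_num)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return actual_thread_num

processed = []
lock = threading.Lock()
def train(fname):
    with lock:
        processed.append(fname)

n = run_from_filelist(["a.txt", "b.txt"], thread_num=4, train_one_file=train)
```

Clamping avoids the bug fixed in one of the commits above, where more threads than files caused fetches to return NaN.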
paddle/fluid/framework/async_executor.h

Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <map>
#include <memory>
#include <mutex>   // NOLINT
#include <set>
#include <string>
#include <thread>  // NOLINT
#include <typeinfo>
#include <vector>
#include "paddle/fluid/framework/data_feed.pb.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/executor_thread_worker.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"

namespace paddle {
namespace framework {
class AsyncExecutor {
 public:
  AsyncExecutor(Scope* scope, const platform::Place& place);
  virtual ~AsyncExecutor() {}
  void RunFromFile(const ProgramDesc& main_program,
                   const std::string& data_feed_desc_str,
                   const std::vector<std::string>& filelist,
                   const int thread_num,
                   const std::vector<std::string>& fetch_names,
                   const bool debug = false);

 private:
  void CreateThreads(ExecutorThreadWorker* worker,
                     const ProgramDesc& main_program,
                     const std::shared_ptr<DataFeed>& reader,
                     const std::vector<std::string>& fetch_var_names,
                     Scope* root_scope, const int thread_index,
                     const bool debug);

 public:
  Scope* root_scope_;
  platform::Place place_;
};

}  // namespace framework
}  // namespace paddle
