
Commit ae37f82

Author: chengduo
Unified ParallelExecutor and Compiler (#15970)
* Unified ParallelExecutor and Compiler
1 parent 7235fd6 commit ae37f82

File tree

4 files changed: 65 additions, 179 deletions

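Note: this commit makes ParallelExecutor a thin wrapper around CompiledProgram and Executor, the path its own deprecation notice already recommends. For orientation, a minimal sketch of that unified usage pattern, assuming PaddlePaddle 1.x with this patch applied; the toy network, the feed name 'x', and the random data are illustrative, not part of the diff:

    import os
    import numpy
    import paddle.fluid as fluid

    os.environ['CPU_NUM'] = '2'  # two CPU places for data parallelism

    # Toy network: one fc layer with a mean loss.
    x = fluid.layers.data(name='x', shape=[4], dtype='float32')
    loss = fluid.layers.mean(fluid.layers.fc(input=x, size=1))
    fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(fluid.default_startup_program())

    # The unified path: compile once, then run through the ordinary Executor.
    compiled = fluid.CompiledProgram(fluid.default_main_program()) \
        .with_data_parallel(loss_name=loss.name)
    feed = {'x': numpy.random.random((8, 4)).astype('float32')}
    loss_val, = exe.run(compiled, feed=feed, fetch_list=[loss.name])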

paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc

Lines changed: 3 additions & 1 deletion
@@ -12,7 +12,9 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 #include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h"
+#include <memory>
 #include <string>
+#include <unordered_map>
 #include <vector>
 #include "paddle/fluid/framework/details/fetch_op_handle.h"
 #include "paddle/fluid/framework/details/multi_devices_helper.h"
@@ -55,7 +57,7 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
   std::vector<FetchOpHandle *> fetch_ops;

   for (auto &fetch_var_name : fetch_tensors) {
-    for (auto &var_map : graph_->Get<details::GraphVars>("vars")) {
+    for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) {
      auto it = var_map.find(fetch_var_name);
      if (it != var_map.end()) {
        fetched_vars[fetch_var_name].push_back(*it->second.rbegin());

python/paddle/fluid/compiler.py

Lines changed: 43 additions & 29 deletions
@@ -17,7 +17,6 @@
 import six
 import sys
 from .. import compat as cpt
-from . import framework

 from . import core
 from . import framework
@@ -36,6 +35,30 @@ def _place_obj(place):
     return p


+def _is_pserver_mode(main_program):
+    main = main_program if main_program \
+        else default_main_program()
+    for op in main.global_block().ops:
+        if op.type in ["send", "recv"]:
+            return True
+    return False
+
+
+def get_available_places(use_cuda):
+    if use_cuda:
+        gpus_env = os.getenv("FLAGS_selected_gpus")
+        if gpus_env:
+            gpus = [int(s) for s in gpus_env.split(",")]
+        else:
+            gpus = [i for i in six.moves.range(core.get_cuda_device_count())]
+        places = [core.CUDAPlace(i) for i in gpus]
+    else:
+        cpu_num = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
+        places = [core.CPUPlace() for _ in six.moves.range(cpu_num)]
+    assert places, "no place for execution"
+    return places
+
+
 class CompiledProgram(object):
     """
     Compiles to Graph for execution.
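The new get_available_places helper centralizes the place-selection logic that was previously duplicated in compiler.py and parallel_executor.py. A hedged usage sketch (assumes PaddlePaddle 1.x with this patch applied; the environment values are illustrative):

    import os
    from paddle.fluid import compiler

    os.environ['CPU_NUM'] = '4'  # request four CPU places
    places = compiler.get_available_places(use_cuda=False)
    print(len(places))           # 4 core.CPUPlace objects

    # With use_cuda=True the helper instead reads FLAGS_selected_gpus (e.g. "0,1")
    # or, if that is unset, enumerates every visible device via
    # core.get_cuda_device_count(). This branch requires a CUDA build of Paddle.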
@@ -127,8 +150,7 @@ def with_data_parallel(self,
             self._exec_strategy = ExecutionStrategy()
         if self._build_strategy is None:
             self._build_strategy = BuildStrategy()
-        self._build_strategy.is_distribution = framework.is_pserver_mode(
-            self._program)
+        self._build_strategy.is_distribution = _is_pserver_mode(self._program)
         return self

     def with_inference_optimize(self, config):
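with_data_parallel now derives is_distribution from the private _is_pserver_mode helper added above, which simply looks for send/recv ops. A small sketch of that heuristic, rewritten against the public Program API so it can be tried standalone (fluid 1.x; the check mirrors the diff, it is not a public helper):

    import paddle.fluid as fluid

    prog = fluid.default_main_program()
    # A program is treated as pserver/distributed if it contains send/recv ops.
    is_distribution = any(op.type in ["send", "recv"]
                          for op in prog.global_block().ops)
    print(is_distribution)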
@@ -153,9 +175,9 @@ def with_inference_optimize(self, config):
     def _with_distributed(self):
         raise NotImplementedError()

-    def _compile_data_parallel(self):
+    def _compile_data_parallel(self, use_cuda=False, scope=None):
         if self._share_vars_from:
-            if self._scope:
+            if scope:
                 sys.stderr.write("share_vars_from is set, scope is ignored.\n")
             if not self._share_vars_from._is_data_parallel:
                 raise ValueError("share_vars_from is not data parallel. Cannot "
@@ -166,23 +188,11 @@ def _compile_data_parallel(self):
                                  "var to share.")
             self._local_scopes = self._share_vars_from._executor.local_scopes()
         else:
+            assert scope is not None, ""
             self._local_scopes = []

-        self._exec_strategy.use_cuda = isinstance(self._place, core.CUDAPlace)
-        if self._exec_strategy.use_cuda:
-            gpus_env = os.getenv("FLAGS_selected_gpus")
-            if gpus_env:
-                gpus = [int(s) for s in gpus_env.split(",")]
-            else:
-                gpus = [
-                    i for i in six.moves.range(core.get_cuda_device_count())
-                ]
-            self._places = [core.CUDAPlace(i) for i in gpus]
-        else:
-            cpu_num = int(
-                os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
-            self._places = [core.CPUPlace() for _ in six.moves.range(cpu_num)]
-        assert self._places, "no place for execution"
+        self._exec_strategy.use_cuda = use_cuda
+        self._places = get_available_places(self._exec_strategy.use_cuda)

         if self._exec_strategy.num_threads == 0:
             if self._exec_strategy.use_cuda:
@@ -197,9 +207,11 @@
         # FIXME(dzhwinter): enable_inplace should be after memory_optimize
         # if turn on python memory optimize, turn off the inplace_pass.
         if self._build_strategy.memory_optimize is None:
-            self._build_strategy.memory_optimize = False if self._program and self._program._is_mem_optimized else True
+            self._build_strategy.memory_optimize = False \
+                if self._program and self._program._is_mem_optimized else True
         if self._build_strategy.enable_inplace is None:
-            self._build_strategy.enable_inplace = False if self._program and self._program._is_mem_optimized else True
+            self._build_strategy.enable_inplace = False \
+                if self._program and self._program._is_mem_optimized else True

         # TODO(wuyi): trainer endpoings should be passed in through
         # build_strategy, not program.xxx.
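When memory_optimize and enable_inplace are left at None, the defaults above are derived from the program's _is_mem_optimized flag. A hedged sketch of setting them explicitly through BuildStrategy so the heuristic is bypassed (fluid 1.x API; the toy network is illustrative):

    import paddle.fluid as fluid

    x = fluid.layers.data(name='x', shape=[4], dtype='float32')
    loss = fluid.layers.mean(fluid.layers.fc(input=x, size=1))

    build_strategy = fluid.BuildStrategy()
    build_strategy.memory_optimize = False  # explicit value; None would trigger the heuristic
    build_strategy.enable_inplace = True

    compiled = fluid.CompiledProgram(fluid.default_main_program()) \
        .with_data_parallel(loss_name=loss.name, build_strategy=build_strategy)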
@@ -221,12 +233,12 @@

         places = list(map(_place_obj, self._places))

-        return core.ParallelExecutor(
-            places,
-            set(self._persistable_vars),
-            cpt.to_text(self._loss_name)
-            if self._loss_name else six.u(''), self._scope, self._local_scopes,
-            self._exec_strategy, self._build_strategy, self._graph)
+        return core.ParallelExecutor(places,
+                                     set(self._persistable_vars),
+                                     cpt.to_text(self._loss_name)
+                                     if self._loss_name else six.u(''), scope,
+                                     self._local_scopes, self._exec_strategy,
+                                     self._build_strategy, self._graph)

     def _compile_inference(self):
         return core.create_paddle_predictor(self._infer_config)
@@ -253,7 +265,9 @@ def _compile(self, scope, place):
         self._scope = scope
         self._place = place
         if self._is_data_parallel:
-            self._executor = self._compile_data_parallel()
+            self._executor = self._compile_data_parallel(
+                use_cuda=isinstance(self._place, core.CUDAPlace),
+                scope=self._scope)
         elif self._is_inference:
             self._executor = self._compile_inference()
         else:

python/paddle/fluid/framework.py

Lines changed: 0 additions & 9 deletions
@@ -87,15 +87,6 @@ def _current_expected_place():
     return _imperative_current_expected_place_


-def is_pserver_mode(main_program):
-    main = main_program if main_program \
-        else default_main_program()
-    for op in main.global_block().ops:
-        if op.type in ["send", "recv"]:
-            return True
-    return False
-
-
 class NameScope(object):
     def __init__(self, name="", parent=None):
         self._children = dict()

python/paddle/fluid/parallel_executor.py

Lines changed: 19 additions & 140 deletions
@@ -13,15 +13,11 @@
 # limitations under the License.

 from __future__ import print_function
-import multiprocessing
 from . import core
 from . import framework
 from . import executor
-from .. import compat as cpt
-import warnings
+from . import compiler
 import sys
-import six
-import os

 __all__ = ['ParallelExecutor']

@@ -97,99 +93,27 @@ def __init__(self,
             'Please use CompiledProgram and Executor. CompiledProgram '
             'is a central place for optimization and Executor is the '
             'unified executor. Example can be found in compiler.py.\n')
-        # step1: get places, the places are used in run too.
-        self._places = []
-        if use_cuda:
-            gpus_env = os.getenv("FLAGS_selected_gpus")
-            if gpus_env:
-                gpus = [int(s) for s in gpus_env.split(",")]
-            else:
-                gpus = [
-                    i for i in six.moves.range(core.get_cuda_device_count())
-                ]
-            self._places = [core.CUDAPlace(i) for i in gpus]
-        else:
-            cpu_num = int(
-                os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
-            self._places = [core.CPUPlace() for _ in six.moves.range(cpu_num)]
-        assert self._places, "no place for execution"

-        # step2: init exec_strategy
-        if exec_strategy is None:
-            exec_strategy = ExecutionStrategy()
-        exec_strategy.use_cuda = use_cuda
-        if exec_strategy.num_threads == 0:
-            if use_cuda:
-                # Experiments on se-resnext shows that too many threads hurt
-                # performance. Worth tunning for other models in the future.
-                exec_strategy.num_threads = len(self._places) * 4
-            else:
-                cpu_num = int(
-                    os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
-                exec_strategy.num_threads = cpu_num * 2
-
-        # step3: init build_strategy
         if build_strategy is None:
             build_strategy = BuildStrategy()
         build_strategy.num_trainers = num_trainers
         build_strategy.trainer_id = trainer_id
-        # FIXME(zcd): is_distribution_ is a temporary field, because in pserver mode,
-        # num_trainers is 1, so the current fields of build_strategy doesn't tell if
-        # it's distributed model.
-        build_strategy.is_distribution = framework.is_pserver_mode(
-            main_program) or num_trainers > 1
-
-        # step4: get main_program, scope, local_scopes
-        main = main_program if main_program \
-            else framework.default_main_program()
-        # FIXME(dzhwinter): enable_inplace should be after memory_optimize
-        # if turn on python memory optimize, turn off the inplace_pass.
-        if build_strategy.memory_optimize is None:
-            build_strategy.memory_optimize = False if main._is_mem_optimized else True
-        if build_strategy.enable_inplace is None:
-            build_strategy.enable_inplace = False if main._is_mem_optimized else True
-        scope = scope if scope is not None else executor.global_scope()
-
-        if share_vars_from and not isinstance(share_vars_from,
-                                              ParallelExecutor):
-            raise TypeError("share_vars_from must be ParallelExecutor.")
-
-        local_scopes = share_vars_from.executor.local_scopes()\
-            if share_vars_from else []
-
-        # step5: check trainers_endpoints, it is used for distribution.
-        trainers_endpoints = main._trainers_endpoints
-        if num_trainers > 1 and trainers_endpoints:
-            assert num_trainers == len(
-                trainers_endpoints), "num_trainers == len(endpoints)"
-            build_strategy.trainers_endpoints = trainers_endpoints
-
-        # step6: get persistable_vars, places. persistable_vars
-        # need be broadcast to other local_scope.
-        persistable_vars = set([
-            cpt.to_text(v.name) for v in [
-                var for var in main.list_vars()
-                if var.persistable and var.type != core.VarDesc.VarType.RAW
-            ]
-        ])
-
-        def place_obj(place):
-            p = core.Place()
-            p.set_place(place)
-            return p
-
-        places = list(map(place_obj, self._places))

-        # step7: init ParallelExecutor
-        # ParallelExecutor API will be deprecated, don't support parallel graph.
-        self._graph = core.Graph(main.desc)
+        self._places = compiler.get_available_places(use_cuda)
+        self._scope = scope if scope is not None else executor.global_scope()

-        self.executor = core.ParallelExecutor(
-            places, persistable_vars,
-            cpt.to_text(loss_name) if loss_name else six.u(''), scope,
-            local_scopes, exec_strategy, build_strategy, self._graph)
+        main_program = main_program if main_program is not None \
+            else framework.default_main_program()

-        self.scope = scope
+        self._compiled_program = compiler.CompiledProgram(main_program)
+        self._compiled_program.with_data_parallel(
+            loss_name=loss_name,
+            build_strategy=build_strategy,
+            exec_strategy=exec_strategy,
+            share_vars_from=share_vars_from)
+        self._place = core.CUDAPlace(0) if use_cuda else core.CPUPlace()
+        self._executor = executor.Executor(self._place)
+        self._compiled_program._compile(place=self._place, scope=self._scope)

     def run(self, fetch_list, feed=None, feed_dict=None, return_numpy=True):
         """
@@ -256,56 +180,11 @@ def run(self, fetch_list, feed=None, feed_dict=None, return_numpy=True):
             loss = pe.run(feed=feeder.feed(cur_batch),
                           fetch_list=[avg_cost.name]))
         """
-        if feed is None and feed_dict is not None:
-            feed = feed_dict
-            print(
-                "`feed_dict` is deprecated. Please use `feed=`",
-                file=sys.stderr)
-
-        if isinstance(feed, dict):
-            feed_tensor_dict = dict()
-            for feed_name in feed:
-                feed_tensor = feed[feed_name]
-                if not isinstance(feed_tensor, core.LoDTensor):
-                    feed_tensor = core.LoDTensor()
-                    # always set to CPU place, since the tensor need to be splitted
-                    # it is fast in CPU
-                    feed_tensor.set(feed[feed_name], core.CPUPlace())
-                feed_tensor_dict[feed_name] = feed_tensor
-
-            self.executor.feed_and_split_tensor_into_local_scopes(
-                feed_tensor_dict)
-        elif isinstance(feed, list) or isinstance(feed, tuple):
-            if len(feed) != len(self._places):
-                raise ValueError(
-                    "Feed a list of tensor, the list should be the same size as places"
-                )
-
-            res = list()
-
-            for i, each in enumerate(feed):
-                if not isinstance(each, dict):
-                    raise TypeError(
-                        "Each element of feed list should be a dict")
-                res_dict = dict()
-                for feed_name in each:
-                    tensor = each[feed_name]
-                    if not isinstance(tensor, core.LoDTensor):
-                        tmp = core.LoDTensor()
-                        tmp.set(tensor, self._places[i])
-                        tensor = tmp
-                    res_dict[feed_name] = tensor
-                res.append(res_dict)
-            self.executor.feed_tensors_into_local_scopes(res)
-
-        fetch_var_name = 'fetch'
-        self.executor.run(fetch_list, fetch_var_name)
-        arr = self.scope.find_var(fetch_var_name).get_lod_tensor_array()
-
-        if return_numpy:
-            return executor.as_numpy(arr)
-
-        return [arr[i] for i in range(len(arr))]
+        return self._executor.run(program=self._compiled_program,
+                                  scope=self._scope,
+                                  feed=feed,
+                                  fetch_list=fetch_list,
+                                  return_numpy=return_numpy)

     @property
     def device_count(self):
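With run() reduced to a call into Executor.run, existing ParallelExecutor call sites keep working unchanged; they simply go through the compiled-program path internally. A hedged sketch of such a legacy call site, reusing the same illustrative toy network as the sketch near the top of this page (fluid 1.x; names and data are not part of the diff):

    import os
    import numpy
    import paddle.fluid as fluid

    os.environ['CPU_NUM'] = '2'

    x = fluid.layers.data(name='x', shape=[4], dtype='float32')
    loss = fluid.layers.mean(fluid.layers.fc(input=x, size=1))
    fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(fluid.default_startup_program())

    # Legacy API: still accepted, now just a wrapper over CompiledProgram + Executor.
    pe = fluid.ParallelExecutor(use_cuda=False, loss_name=loss.name)
    loss_val, = pe.run(feed={'x': numpy.random.random((8, 4)).astype('float32')},
                       fetch_list=[loss.name])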
