@@ -13,15 +13,11 @@
 # limitations under the License.
 
 from __future__ import print_function
-import multiprocessing
 from . import core
 from . import framework
 from . import executor
-from .. import compat as cpt
-import warnings
+from . import compiler
 import sys
-import six
-import os
 
 __all__ = ['ParallelExecutor']
 
@@ -97,99 +93,27 @@ def __init__(self,
             'Please use CompiledProgram and Executor. CompiledProgram '
             'is a central place for optimization and Executor is the '
             'unified executor. Example can be found in compiler.py.\n')
-        # step1: get places, the places are used in run too.
-        self._places = []
-        if use_cuda:
-            gpus_env = os.getenv("FLAGS_selected_gpus")
-            if gpus_env:
-                gpus = [int(s) for s in gpus_env.split(",")]
-            else:
-                gpus = [
-                    i for i in six.moves.range(core.get_cuda_device_count())
-                ]
-            self._places = [core.CUDAPlace(i) for i in gpus]
-        else:
-            cpu_num = int(
-                os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
-            self._places = [core.CPUPlace() for _ in six.moves.range(cpu_num)]
-        assert self._places, "no place for execution"
 
-        # step2: init exec_strategy
-        if exec_strategy is None:
-            exec_strategy = ExecutionStrategy()
-        exec_strategy.use_cuda = use_cuda
-        if exec_strategy.num_threads == 0:
-            if use_cuda:
-                # Experiments on se-resnext shows that too many threads hurt
-                # performance. Worth tunning for other models in the future.
-                exec_strategy.num_threads = len(self._places) * 4
-            else:
-                cpu_num = int(
-                    os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
-                exec_strategy.num_threads = cpu_num * 2
-
-        # step3: init build_strategy
         if build_strategy is None:
             build_strategy = BuildStrategy()
         build_strategy.num_trainers = num_trainers
         build_strategy.trainer_id = trainer_id
-        # FIXME(zcd): is_distribution_ is a temporary field, because in pserver mode,
-        # num_trainers is 1, so the current fields of build_strategy doesn't tell if
-        # it's distributed model.
-        build_strategy.is_distribution = framework.is_pserver_mode(
-            main_program) or num_trainers > 1
-
-        # step4: get main_program, scope, local_scopes
-        main = main_program if main_program \
-            else framework.default_main_program()
-        # FIXME(dzhwinter): enable_inplace should be after memory_optimize
-        # if turn on python memory optimize, turn off the inplace_pass.
-        if build_strategy.memory_optimize is None:
-            build_strategy.memory_optimize = False if main._is_mem_optimized else True
-        if build_strategy.enable_inplace is None:
-            build_strategy.enable_inplace = False if main._is_mem_optimized else True
-        scope = scope if scope is not None else executor.global_scope()
-
-        if share_vars_from and not isinstance(share_vars_from,
-                                              ParallelExecutor):
-            raise TypeError("share_vars_from must be ParallelExecutor.")
-
-        local_scopes = share_vars_from.executor.local_scopes()\
-            if share_vars_from else []
-
-        # step5: check trainers_endpoints, it is used for distribution.
-        trainers_endpoints = main._trainers_endpoints
-        if num_trainers > 1 and trainers_endpoints:
-            assert num_trainers == len(
-                trainers_endpoints), "num_trainers == len(endpoints)"
-            build_strategy.trainers_endpoints = trainers_endpoints
-
-        # step6: get persistable_vars, places. persistable_vars
-        # need be broadcast to other local_scope.
-        persistable_vars = set([
-            cpt.to_text(v.name) for v in [
-                var for var in main.list_vars()
-                if var.persistable and var.type != core.VarDesc.VarType.RAW
-            ]
-        ])
-
-        def place_obj(place):
-            p = core.Place()
-            p.set_place(place)
-            return p
-
-        places = list(map(place_obj, self._places))
 
-        # step7: init ParallelExecutor
-        # ParallelExecutor API will be deprecated, don't support parallel graph.
-        self._graph = core.Graph(main.desc)
+        self._places = compiler.get_available_places(use_cuda)
+        self._scope = scope if scope is not None else executor.global_scope()
 
-        self.executor = core.ParallelExecutor(
-            places, persistable_vars,
-            cpt.to_text(loss_name) if loss_name else six.u(''), scope,
-            local_scopes, exec_strategy, build_strategy, self._graph)
+        main_program = main_program if main_program is not None \
+            else framework.default_main_program()
 
-        self.scope = scope
+        self._compiled_program = compiler.CompiledProgram(main_program)
+        self._compiled_program.with_data_parallel(
+            loss_name=loss_name,
+            build_strategy=build_strategy,
+            exec_strategy=exec_strategy,
+            share_vars_from=share_vars_from)
+        self._place = core.CUDAPlace(0) if use_cuda else core.CPUPlace()
+        self._executor = executor.Executor(self._place)
+        self._compiled_program._compile(place=self._place, scope=self._scope)
 
     def run(self, fetch_list, feed=None, feed_dict=None, return_numpy=True):
         """
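The deprecation notice above points callers at CompiledProgram plus Executor, and the rewritten __init__ now simply builds that pair internally: place discovery, strategy setup, and graph compilation all move into CompiledProgram. For reference, a minimal sketch of the recommended direct usage, assuming the 1.x paddle.fluid API; the toy network (X, fc, mean loss) is illustrative and not part of this commit:

# Minimal sketch of the CompiledProgram + Executor pattern the rewritten
# __init__ now wraps, assuming the 1.x paddle.fluid API. The toy network
# (X, fc, mean loss) is illustrative, not part of this commit.
import numpy
import paddle.fluid as fluid
import paddle.fluid.compiler as compiler

place = fluid.CPUPlace()
exe = fluid.Executor(place)

data = fluid.layers.data(name='X', shape=[1], dtype='float32')
hidden = fluid.layers.fc(input=data, size=10)
loss = fluid.layers.mean(hidden)
fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)

exe.run(fluid.default_startup_program())

# Compile once with data parallelism, then run it through the plain
# Executor -- the same pair of objects __init__ now builds internally.
compiled_prog = compiler.CompiledProgram(
    fluid.default_main_program()).with_data_parallel(loss_name=loss.name)

x = numpy.random.random(size=(10, 1)).astype('float32')
loss_value, = exe.run(compiled_prog,
                      feed={"X": x},
                      fetch_list=[loss.name])

Because the compilation happens once up front, the remaining ParallelExecutor methods can stay thin dispatches onto the compiled program, as the run() hunk below shows.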
@@ -256,56 +180,11 @@ def run(self, fetch_list, feed=None, feed_dict=None, return_numpy=True):
             loss = pe.run(feed=feeder.feed(cur_batch),
                           fetch_list=[avg_cost.name]))
         """
-        if feed is None and feed_dict is not None:
-            feed = feed_dict
-            print(
-                "`feed_dict` is deprecated. Please use `feed=`",
-                file=sys.stderr)
-
-        if isinstance(feed, dict):
-            feed_tensor_dict = dict()
-            for feed_name in feed:
-                feed_tensor = feed[feed_name]
-                if not isinstance(feed_tensor, core.LoDTensor):
-                    feed_tensor = core.LoDTensor()
-                    # always set to CPU place, since the tensor need to be splitted
-                    # it is fast in CPU
-                    feed_tensor.set(feed[feed_name], core.CPUPlace())
-                feed_tensor_dict[feed_name] = feed_tensor
-
-            self.executor.feed_and_split_tensor_into_local_scopes(
-                feed_tensor_dict)
-        elif isinstance(feed, list) or isinstance(feed, tuple):
-            if len(feed) != len(self._places):
-                raise ValueError(
-                    "Feed a list of tensor, the list should be the same size as places"
-                )
-
-            res = list()
-
-            for i, each in enumerate(feed):
-                if not isinstance(each, dict):
-                    raise TypeError(
-                        "Each element of feed list should be a dict")
-                res_dict = dict()
-                for feed_name in each:
-                    tensor = each[feed_name]
-                    if not isinstance(tensor, core.LoDTensor):
-                        tmp = core.LoDTensor()
-                        tmp.set(tensor, self._places[i])
-                        tensor = tmp
-                    res_dict[feed_name] = tensor
-                res.append(res_dict)
-            self.executor.feed_tensors_into_local_scopes(res)
-
-        fetch_var_name = 'fetch'
-        self.executor.run(fetch_list, fetch_var_name)
-        arr = self.scope.find_var(fetch_var_name).get_lod_tensor_array()
-
-        if return_numpy:
-            return executor.as_numpy(arr)
-
-        return [arr[i] for i in range(len(arr))]
+        return self._executor.run(program=self._compiled_program,
+                                  scope=self._scope,
+                                  feed=feed,
+                                  fetch_list=fetch_list,
+                                  return_numpy=return_numpy)
 
     @property
     def device_count(self):
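After this change run() is a pure pass-through: the feed handling deleted above (splitting a single dict across places, or accepting one dict per place) is performed inside Executor.run when it is given a data-parallel compiled program. A rough sketch of both call styles, continuing the toy network from the sketch above; the CPU_NUM setting and the batch split are illustrative assumptions, not part of this commit:

# Both feed styles still work after the rewrite; the per-place splitting
# that used to live in ParallelExecutor.run now happens inside Executor.run.
# CPU_NUM=2 is an illustrative assumption so the list feed below has
# exactly one dict per place; it must be set before places are queried.
import os
os.environ['CPU_NUM'] = '2'

pe = fluid.ParallelExecutor(use_cuda=False, loss_name=loss.name)

# Style 1: a single dict -- the tensor is split across all places.
loss_value, = pe.run(feed={"X": x}, fetch_list=[loss.name])

# Style 2: one dict per place -- no splitting is performed.
per_place_feed = [{"X": x[0:5]}, {"X": x[5:10]}]
loss_value, = pe.run(feed=per_place_feed, fetch_list=[loss.name])

Since the same Executor.run path serves both ParallelExecutor and CompiledProgram callers, the two call styles above should behave identically whichever front end constructed the compiled program.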