Skip to content

Commit 143023b

Browse files
authored
1.7 bug fix (#22862)
* test=develop, bug fix for trainer_factory (#22751) * test=develop, optimize distributedstrategy (#22677)
1 parent b7937d2 commit 143023b

File tree

3 files changed

+201
-67
lines changed

3 files changed

+201
-67
lines changed

python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/distributed_strategy.py

Lines changed: 173 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -19,22 +19,87 @@
1919

2020
import os
2121
import paddle.fluid as fluid
22-
from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig, ServerRuntimeConfig
22+
from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig, ServerRuntimeConfig, DistributedMode
2323

2424

2525
class TrainerRuntimeConfig(object):
2626
def __init__(self):
27+
self.mode = None
28+
num_threads = os.getenv("CPU_NUM", "1")
29+
2730
self.runtime_configs = {}
31+
self.runtime_configs['communicator_max_merge_var_num'] = os.getenv(
32+
"FLAGS_communicator_max_merge_var_num", num_threads)
33+
self.runtime_configs['communicator_send_queue_size'] = os.getenv(
34+
"FLAGS_communicator_send_queue_size", num_threads)
35+
self.runtime_configs[
36+
'communicator_independent_recv_thread'] = os.getenv(
37+
"FLAGS_communicator_independent_recv_thread", "1")
38+
self.runtime_configs[
39+
'communicator_min_send_grad_num_before_recv'] = os.getenv(
40+
"FLAGS_communicator_min_send_grad_num_before_recv", num_threads)
41+
self.runtime_configs['communicator_thread_pool_size'] = os.getenv(
42+
"FLAGS_communicator_thread_pool_size", "5")
43+
self.runtime_configs['communicator_send_wait_times'] = os.getenv(
44+
"FLAGS_communicator_send_wait_times", "5")
45+
self.runtime_configs['communicator_is_sgd_optimizer'] = os.getenv(
46+
"FLAGS_communicator_is_sgd_optimizer", "1")
47+
2848
# not used
2949
self.runtime_configs['rpc_deadline'] = os.getenv("FLAGS_rpc_deadline",
3050
"180000")
3151
self.runtime_configs['rpc_retry_times'] = os.getenv(
3252
"FLAGS_rpc_retry_times", "3")
3353

3454
def get_communicator_flags(self):
    """Return the communicator flags that apply to the current mode.

    Which entries of ``self.runtime_configs`` are returned depends on
    ``self.mode``:

    * ``None`` or ``DistributedMode.ASYNC``: every entry.
    * ``DistributedMode.SYNC`` or ``DistributedMode.HALF_ASYNC``:
      ``communicator_max_merge_var_num``, ``communicator_send_wait_times``,
      ``communicator_thread_pool_size`` and
      ``communicator_send_queue_size``.
    * ``DistributedMode.GEO``: ``communicator_thread_pool_size`` and
      ``communicator_send_wait_times``.

    In the SYNC / HALF_ASYNC case ``communicator_max_merge_var_num`` and
    ``communicator_send_queue_size`` are overwritten in
    ``self.runtime_configs`` with the value of ``CPU_NUM`` (default
    ``"1"``) whenever they differ from it, and a warning is printed.

    Returns:
        dict: flag name mapped to its value converted with ``str``.

    Raises:
        ValueError: if ``self.mode`` is none of the modes listed above.
    """
    num_threads = os.getenv("CPU_NUM", "1")

    if self.mode is None or self.mode == DistributedMode.ASYNC:
        need_keys = list(self.runtime_configs)
    elif self.mode in (DistributedMode.SYNC, DistributedMode.HALF_ASYNC):
        mode_str = "sync or half_async"
        need_keys = [
            'communicator_max_merge_var_num',
            'communicator_send_wait_times',
            'communicator_thread_pool_size',
            'communicator_send_queue_size',
        ]
        for key in ('communicator_max_merge_var_num',
                    'communicator_send_queue_size'):
            current = self.runtime_configs[key]
            # Values may have been stored as non-strings (the return value
            # below normalises with str()), so compare the string forms;
            # otherwise e.g. int 100 vs CPU_NUM "100" warns needlessly.
            if str(current) != num_threads:
                print('WARNING: In {0} mode, {1} '
                      'must be equal to CPU_NUM. But received, '
                      '{1} = {2}, CPU_NUM = '
                      '{3}. {1} will be forced to {3}.'.format(
                          mode_str, key, current, num_threads))
                self.runtime_configs[key] = num_threads
    elif self.mode == DistributedMode.GEO:
        need_keys = [
            'communicator_thread_pool_size',
            'communicator_send_wait_times',
        ]
    else:
        raise ValueError("Unsupported Mode")

    return {key: str(self.runtime_configs[key]) for key in need_keys}
101+
102+
def display(self, configs):
38103
raw0, raw1, length = 45, 5, 50
39104
h_format = "{:^45s}{:<5s}\n"
40105
l_format = "{:<45s}{:<5s}\n"
@@ -47,14 +112,17 @@ def __repr__(self):
47112
draws += h_format.format("TrainerRuntimeConfig Overview", "Value")
48113
draws += line + "\n"
49114

50-
for k, v in self.get_communicator_flags().items():
115+
for k, v in configs.items():
51116
draws += l_format.format(k, v)
52117

53118
draws += border
54119

55120
_str = "\n{}\n".format(draws)
56121
return _str
57122

123+
def __repr__(self):
    """Return the string ``display`` builds for the current communicator flags."""
    flags = self.get_communicator_flags()
    return self.display(flags)
125+
58126

59127
class DistributedStrategy(object):
60128
def __init__(self):
@@ -105,6 +173,12 @@ def set_program_config(self, config):
105173
raise TypeError(
106174
"program_config only accept input type: dict or DistributeTranspilerConfig"
107175
)
176+
self.check_program_config()
177+
178+
def check_program_config(self):
    """Base-class placeholder for the program-config check.

    Raises:
        NotImplementedError: always; derived strategy classes provide the
            real implementation.
    """
    message = (
        "check_program_config must be implemented by derived class. You should use StrategyFactory to create DistributedStrategy."
    )
    raise NotImplementedError(message)
108182

109183
def get_trainer_runtime_config(self):
110184
return self._trainer_runtime_config
@@ -123,6 +197,12 @@ def set_trainer_runtime_config(self, config):
123197
raise TypeError(
124198
"trainer_runtime_config only accept input type: dict or TrainerRuntimeConfig"
125199
)
200+
self.check_trainer_runtime_config()
201+
202+
def check_trainer_runtime_config(self):
    """Base-class placeholder for the trainer-runtime-config check.

    Raises:
        NotImplementedError: always; derived strategy classes provide the
            real implementation.
    """
    message = (
        "check_trainer_runtime_config must be implemented by derived class. You should use StrategyFactory to create DistributedStrategy."
    )
    raise NotImplementedError(message)
126206

127207
def get_server_runtime_config(self):
128208
return self._server_runtime_config
@@ -141,6 +221,12 @@ def set_server_runtime_config(self, config):
141221
raise TypeError(
142222
"server_runtime_config only accept input type: dict or ServerRuntimeConfig"
143223
)
224+
self.check_server_runtime_config()
225+
226+
def check_server_runtime_config(self):
    """Base-class placeholder for the server-runtime-config check.

    Raises:
        NotImplementedError: always; derived strategy classes provide the
            real implementation.
    """
    message = (
        "check_server_runtime_config must be implemented by derived class. You should use StrategyFactory to create DistributedStrategy."
    )
    raise NotImplementedError(message)
144230

145231
def get_execute_strategy(self):
146232
return self._execute_strategy
@@ -159,6 +245,12 @@ def set_execute_strategy(self, config):
159245
raise TypeError(
160246
"execute_strategy only accept input type: dict or ExecutionStrategy"
161247
)
248+
self.check_execute_strategy()
249+
250+
def check_execute_strategy(self):
    """Base-class placeholder for the execute-strategy check.

    Raises:
        NotImplementedError: always; derived strategy classes provide the
            real implementation.
    """
    message = (
        "check_execute_strategy must be implemented by derived class. You should use StrategyFactory to create DistributedStrategy."
    )
    raise NotImplementedError(message)
162254

163255
def get_build_strategy(self):
164256
return self._build_strategy
@@ -176,106 +268,121 @@ def set_build_strategy(self, config):
176268
else:
177269
raise TypeError(
178270
"build_strategy only accept input type: dict or BuildStrategy")
271+
self.check_build_strategy()
272+
273+
def check_build_strategy(self):
    """Base-class placeholder for the build-strategy check.

    Raises:
        NotImplementedError: always; derived strategy classes provide the
            real implementation.
    """
    message = (
        "check_build_strategy must be implemented by derived class. You should use StrategyFactory to create DistributedStrategy."
    )
    raise NotImplementedError(message)
179277

180278

181279
class SyncStrategy(DistributedStrategy):
    """Strategy whose ``check_*`` hooks pin the settings used for sync mode."""

    def __init__(self):
        """Initialise the base strategy, then run every ``check_*`` hook once."""
        super(SyncStrategy, self).__init__()
        hooks = (
            self.check_program_config,
            self.check_trainer_runtime_config,
            self.check_server_runtime_config,
            self.check_build_strategy,
            self.check_execute_strategy,
        )
        for hook in hooks:
            hook()

    def check_trainer_runtime_config(self):
        """Tag the trainer runtime config with ``DistributedMode.SYNC``."""
        runtime_config = self._trainer_runtime_config
        runtime_config.mode = DistributedMode.SYNC

    def check_program_config(self):
        """Set the program-config switches this strategy relies on."""
        program_config = self._program_config
        program_config.sync_mode = False
        program_config.runtime_split_send_recv = True
        program_config.half_async = True
        program_config.completely_not_async = True

    def check_server_runtime_config(self):
        """Nothing is enforced on the server runtime config."""
        return None

    def check_execute_strategy(self):
        """Turn on ``use_thread_barrier`` in the execute strategy."""
        execute_strategy = self._execute_strategy
        execute_strategy.use_thread_barrier = True

    def check_build_strategy(self):
        """Turn on ``async_mode`` in the build strategy."""
        build_strategy = self._build_strategy
        build_strategy.async_mode = True
205305

206306

207307
class AsyncStrategy(DistributedStrategy):
    """Strategy whose ``check_*`` hooks pin the settings used for async mode."""

    def __init__(self):
        """Initialise the base strategy, then run every ``check_*`` hook once."""
        super(AsyncStrategy, self).__init__()
        hooks = (
            self.check_program_config,
            self.check_trainer_runtime_config,
            self.check_server_runtime_config,
            self.check_build_strategy,
            self.check_execute_strategy,
        )
        for hook in hooks:
            hook()

    def check_trainer_runtime_config(self):
        """Tag the trainer runtime config with ``DistributedMode.ASYNC``."""
        runtime_config = self._trainer_runtime_config
        runtime_config.mode = DistributedMode.ASYNC

    def check_program_config(self):
        """Set the program-config switches this strategy relies on."""
        program_config = self._program_config
        program_config.sync_mode = False
        program_config.runtime_split_send_recv = True

    def check_server_runtime_config(self):
        """Nothing is enforced on the server runtime config."""
        return None

    def check_execute_strategy(self):
        """Nothing is enforced on the execute strategy."""
        return None

    def check_build_strategy(self):
        """Turn on ``async_mode`` in the build strategy."""
        build_strategy = self._build_strategy
        build_strategy.async_mode = True
237331

238332

239333
class HalfAsyncStrategy(DistributedStrategy):
    """Strategy whose ``check_*`` hooks pin the settings used for half-async mode."""

    def __init__(self):
        """Initialise the base strategy, then run every ``check_*`` hook once."""
        super(HalfAsyncStrategy, self).__init__()
        hooks = (
            self.check_program_config,
            self.check_trainer_runtime_config,
            self.check_server_runtime_config,
            self.check_build_strategy,
            self.check_execute_strategy,
        )
        for hook in hooks:
            hook()

    def check_trainer_runtime_config(self):
        """Tag the trainer runtime config with ``DistributedMode.HALF_ASYNC``."""
        runtime_config = self._trainer_runtime_config
        runtime_config.mode = DistributedMode.HALF_ASYNC

    def check_program_config(self):
        """Set the program-config switches this strategy relies on."""
        program_config = self._program_config
        program_config.sync_mode = False
        program_config.runtime_split_send_recv = True
        program_config.half_async = True

    def check_server_runtime_config(self):
        """Nothing is enforced on the server runtime config."""
        return None

    def check_execute_strategy(self):
        """Turn on ``use_thread_barrier`` in the execute strategy."""
        execute_strategy = self._execute_strategy
        execute_strategy.use_thread_barrier = True

    def check_build_strategy(self):
        """Turn on ``async_mode`` in the build strategy."""
        build_strategy = self._build_strategy
        build_strategy.async_mode = True
262358

263359

264360
class GeoStrategy(DistributedStrategy):
    """Strategy whose ``check_*`` hooks pin the settings used for GEO mode."""

    def __init__(self, update_frequency=100):
        """Initialise the base strategy and run every ``check_*`` hook once.

        Args:
            update_frequency: stored as ``geo_sgd_need_push_nums`` on the
                program config before the hooks run.
        """
        super(GeoStrategy, self).__init__()
        self._program_config.geo_sgd_need_push_nums = update_frequency
        hooks = (
            self.check_program_config,
            self.check_trainer_runtime_config,
            self.check_server_runtime_config,
            self.check_build_strategy,
            self.check_execute_strategy,
        )
        for hook in hooks:
            hook()

    def check_program_config(self):
        """Set the program-config switches this strategy relies on."""
        program_config = self._program_config
        program_config.sync_mode = False
        program_config.runtime_split_send_recv = True
        program_config.geo_sgd_mode = True

    def check_trainer_runtime_config(self):
        """Tag the trainer runtime config with ``DistributedMode.GEO``."""
        runtime_config = self._trainer_runtime_config
        runtime_config.mode = DistributedMode.GEO

    def check_server_runtime_config(self):
        """Nothing is enforced on the server runtime config."""
        return None

    def check_execute_strategy(self):
        """Nothing is enforced on the execute strategy."""
        return None

    def check_build_strategy(self):
        """Turn on ``async_mode`` in the build strategy."""
        build_strategy = self._build_strategy
        build_strategy.async_mode = True
279386

280387

281388
class StrategyFactory(object):

python/paddle/fluid/tests/unittests/test_distributed_strategy.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,15 @@ def test_sync_strategy(self):
5252
self.assertRaises(Exception, strategy.set_program_config,
5353
program_config_illegal)
5454

55+
trainer_runtime_config = strategy.get_trainer_runtime_config()
56+
trainer_runtime_config.runtime_configs[
57+
'communicator_send_queue_size'] = '50'
58+
runtime_configs = trainer_runtime_config.get_communicator_flags()
59+
self.assertIn('communicator_send_queue_size', runtime_configs)
60+
self.assertNotIn('communicator_independent_recv_thread',
61+
runtime_configs)
62+
self.assertEqual(runtime_configs['communicator_send_queue_size'], '2')
63+
5564
def test_geo_strategy(self):
5665
strategy = StrategyFactory.create_geo_strategy(5)
5766
self.assertEqual(strategy._program_config.sync_mode, False)
@@ -82,6 +91,14 @@ def test_geo_strategy(self):
8291
self.assertRaises(Exception, strategy.set_build_strategy,
8392
build_strategy_illegal)
8493

94+
os.environ["CPU_NUM"] = '100'
95+
trainer_runtime_config = strategy.get_trainer_runtime_config()
96+
runtime_configs = trainer_runtime_config.get_communicator_flags()
97+
self.assertIn('communicator_thread_pool_size', runtime_configs)
98+
self.assertIn('communicator_send_wait_times', runtime_configs)
99+
self.assertNotIn('communicator_independent_recv_thread',
100+
runtime_configs)
101+
85102
def test_async_strategy(self):
86103
os.environ["CPU_NUM"] = '100'
87104

@@ -164,6 +181,16 @@ def test_half_async_strategy(self):
164181
self.assertRaises(Exception, strategy.set_server_runtime_config,
165182
server_runtime_config_illegal)
166183

184+
os.environ["CPU_NUM"] = '100'
185+
trainer_runtime_config = strategy.get_trainer_runtime_config()
186+
trainer_runtime_config.runtime_configs[
187+
'communicator_send_queue_size'] = '50'
188+
runtime_configs = trainer_runtime_config.get_communicator_flags()
189+
self.assertIn('communicator_send_queue_size', runtime_configs)
190+
self.assertNotIn('communicator_independent_recv_thread',
191+
runtime_configs)
192+
self.assertEqual(runtime_configs['communicator_send_queue_size'], '100')
193+
167194

168195
class TestCreateDefaultStrategy(unittest.TestCase):
169196
def test_default_strategy(self):

0 commit comments

Comments
 (0)