Skip to content

Commit c000f8a

Browse files
authored
add texttable for pretty flag output (#22584) (#22626)
pretty-print output for the communicator flags
1 parent f517fb6 commit c000f8a

File tree

2 files changed

+98
-53
lines changed

2 files changed

+98
-53
lines changed

python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/distributed_strategy.py

Lines changed: 92 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -24,51 +24,35 @@
2424

2525
class TrainerRuntimeConfig(object):
2626
def __init__(self):
27-
self.max_merge_var_num = os.getenv(
28-
"FLAGS_communicator_max_merge_var_num", "20")
29-
self.send_queue_size = os.getenv("FLAGS_communicator_send_queue_size",
30-
"20")
31-
self.independent_recv_thread = os.getenv(
32-
"FLAGS_communicator_independent_recv_thread", "1")
33-
self.min_send_grad_num_before_recv = os.getenv(
34-
"FLAGS_communicator_min_send_grad_num_before_recv", "20")
35-
self.thread_pool_size = os.getenv("FLAGS_communicator_thread_pool_size",
36-
"5")
37-
self.send_wait_times = os.getenv("FLAGS_communicator_send_wait_times",
38-
"5")
39-
self.fake_rpc = os.getenv("FLAGS_communicator_fake_rpc", "0")
40-
self.merge_sparse_grad = os.getenv(
41-
"FLAGS_communicator_merge_sparse_grad", "1")
42-
self.is_sgd_optimizer = os.getenv("FLAGS_communicator_is_sgd_optimizer",
43-
"1")
44-
27+
self.runtime_configs = {}
4528
# not used
46-
self._rpc_deadline = os.getenv("FLAGS_rpc_deadline", "180000")
47-
self._rpc_retry_times = os.getenv("FLAGS_rpc_retry_times", "3")
29+
self.runtime_configs['rpc_deadline'] = os.getenv("FLAGS_rpc_deadline",
30+
"180000")
31+
self.runtime_configs['rpc_retry_times'] = os.getenv(
32+
"FLAGS_rpc_retry_times", "3")
4833

4934
def get_communicator_flags(self):
50-
_communicator_flags = dict()
51-
_communicator_flags["communicator_max_merge_var_num"] = str(
52-
self.max_merge_var_num)
53-
_communicator_flags["communicator_send_queue_size"] = str(
54-
self.send_queue_size)
55-
_communicator_flags["communicator_independent_recv_thread"] = str(
56-
self.independent_recv_thread)
57-
_communicator_flags["communicator_min_send_grad_num_before_recv"] = str(
58-
self.min_send_grad_num_before_recv)
59-
_communicator_flags["communicator_thread_pool_size"] = str(
60-
self.thread_pool_size)
61-
_communicator_flags["communicator_send_wait_times"] = str(
62-
self.send_wait_times)
63-
_communicator_flags["communicator_is_sgd_optimizer"] = str(
64-
self.is_sgd_optimizer)
65-
return _communicator_flags
35+
return self.runtime_configs
6636

6737
def __repr__(self):
68-
_str = "please check that TrainerRuntimeConfig is as expected:\n"
69-
_communicator_flags = self.get_communicator_flags()
70-
for key in _communicator_flags:
71-
_str += "{}: {}\n".format(key, _communicator_flags[key])
38+
raw0, raw1, length = 45, 5, 50
39+
h_format = "{:^45s}{:<5s}\n"
40+
l_format = "{:<45s}{:<5s}\n"
41+
42+
border = "".join(["="] * length)
43+
line = "".join(["-"] * length)
44+
45+
draws = ""
46+
draws += border + "\n"
47+
draws += h_format.format("TrainerRuntimeConfig Overview", "Value")
48+
draws += line + "\n"
49+
50+
for k, v in self.get_communicator_flags().items():
51+
draws += l_format.format(k, v)
52+
53+
draws += border
54+
55+
_str = "\n{}\n".format(draws)
7256
return _str
7357

7458

@@ -77,9 +61,11 @@ def __init__(self):
7761
self._program_config = DistributeTranspilerConfig()
7862
self._trainer_runtime_config = TrainerRuntimeConfig()
7963
self._server_runtime_config = ServerRuntimeConfig()
64+
num_threads = int(os.getenv("CPU_NUM", "1"))
65+
8066
self._execute_strategy = fluid.ExecutionStrategy()
8167
self._build_strategy = fluid.BuildStrategy()
82-
num_threads = int(os.getenv("CPU_NUM", "1"))
68+
8369
self._execute_strategy.num_threads = num_threads
8470
if num_threads > 1:
8571
self._build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
@@ -110,9 +96,9 @@ def set_trainer_runtime_config(self, config):
11096
if isinstance(config, TrainerRuntimeConfig):
11197
self._trainer_runtime_config = config
11298
elif isinstance(config, dict):
113-
for key in config:
114-
if hasattr(self._trainer_runtime_config, key):
115-
setattr(self._trainer_runtime_config, key, config[key])
99+
for key, Value in config.items():
100+
if key in self._trainer_runtime_config.runtime_configs:
101+
self._trainer_runtime_config.runtime_configs[key] = Value
116102
else:
117103
raise ValueError(
118104
"TrainerRuntimeConfig doesn't have key: {}".format(key))
@@ -182,6 +168,21 @@ def __init__(self):
182168
self._program_config.runtime_split_send_recv = False
183169
self._build_strategy.async_mode = False
184170

171+
num_threads = os.getenv("CPU_NUM", "1")
172+
173+
self._trainer_runtime_config.runtime_configs[
174+
'communicator_max_merge_var_num'] = os.getenv(
175+
"FLAGS_communicator_max_merge_var_num", num_threads)
176+
self._trainer_runtime_config.runtime_configs[
177+
'communicator_send_wait_times'] = os.getenv(
178+
"FLAGS_communicator_send_wait_times", "5")
179+
self._trainer_runtime_config.runtime_configs[
180+
'communicator_thread_pool_size'] = os.getenv(
181+
"FLAGS_communicator_thread_pool_size", "10")
182+
self._trainer_runtime_config.runtime_configs[
183+
'communicator_send_queue_size'] = os.getenv(
184+
"FLAGS_communicator_send_queue_size", num_threads)
185+
185186

186187
class AsyncStrategy(DistributedStrategy):
187188
def __init__(self):
@@ -190,6 +191,30 @@ def __init__(self):
190191
self._program_config.runtime_split_send_recv = True
191192
self._build_strategy.async_mode = True
192193

194+
num_threads = os.getenv("CPU_NUM", "1")
195+
196+
self._trainer_runtime_config.runtime_configs[
197+
'communicator_max_merge_var_num'] = os.getenv(
198+
"FLAGS_communicator_max_merge_var_num", num_threads)
199+
self._trainer_runtime_config.runtime_configs[
200+
'communicator_independent_recv_thread'] = os.getenv(
201+
"FLAGS_communicator_independent_recv_thread", "0")
202+
self._trainer_runtime_config.runtime_configs[
203+
'communicator_min_send_grad_num_before_recv'] = os.getenv(
204+
"FLAGS_communicator_min_send_grad_num_before_recv", num_threads)
205+
self._trainer_runtime_config.runtime_configs[
206+
'communicator_thread_pool_size'] = os.getenv(
207+
"FLAGS_communicator_thread_pool_size", "10")
208+
self._trainer_runtime_config.runtime_configs[
209+
'communicator_send_wait_times'] = os.getenv(
210+
"FLAGS_communicator_send_wait_times", "5")
211+
self._trainer_runtime_config.runtime_configs[
212+
'communicator_is_sgd_optimizer'] = os.getenv(
213+
"FLAGS_communicator_is_sgd_optimizer", "1")
214+
self._trainer_runtime_config.runtime_configs[
215+
'communicator_send_queue_size'] = os.getenv(
216+
"FLAGS_communicator_send_queue_size", num_threads)
217+
193218

194219
class HalfAsyncStrategy(DistributedStrategy):
195220
def __init__(self):
@@ -200,15 +225,37 @@ def __init__(self):
200225
self._build_strategy.async_mode = True
201226
self._execute_strategy.use_thread_barrier = True
202227

228+
num_threads = os.getenv("CPU_NUM", "1")
229+
230+
self._trainer_runtime_config.runtime_configs[
231+
'communicator_max_merge_var_num'] = os.getenv(
232+
"FLAGS_communicator_max_merge_var_num", num_threads)
233+
self._trainer_runtime_config.runtime_configs[
234+
'communicator_send_wait_times'] = os.getenv(
235+
"FLAGS_communicator_send_wait_times", "5")
236+
self._trainer_runtime_config.runtime_configs[
237+
'communicator_thread_pool_size'] = os.getenv(
238+
"FLAGS_communicator_thread_pool_size", "10")
239+
self._trainer_runtime_config.runtime_configs[
240+
'communicator_send_queue_size'] = os.getenv(
241+
"FLAGS_communicator_send_queue_size", num_threads)
242+
203243

204244
class GeoStrategy(DistributedStrategy):
205245
def __init__(self, update_frequency=100):
206246
super(GeoStrategy, self).__init__()
207247
self._program_config.sync_mode = False
208248
self._program_config.runtime_split_send_recv = True
209-
self._build_strategy.async_mode = True
210249
self._program_config.geo_sgd_mode = True
211250
self._program_config.geo_sgd_need_push_nums = update_frequency
251+
self._build_strategy.async_mode = True
252+
253+
self._trainer_runtime_config.runtime_configs[
254+
'communicator_thread_pool_size'] = os.getenv(
255+
"FLAGS_communicator_thread_pool_size", "10")
256+
self._trainer_runtime_config.runtime_configs[
257+
'communicator_send_wait_times'] = os.getenv(
258+
"FLAGS_communicator_send_wait_times", "5")
212259

213260

214261
class StrategyFactory(object):

python/paddle/fluid/tests/unittests/test_distributed_strategy.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -84,30 +84,28 @@ def test_geo_strategy(self):
8484
build_strategy_illegal)
8585

8686
def test_async_strategy(self):
87+
os.environ["CPU_NUM"] = '100'
88+
8789
strategy = StrategyFactory.create_async_strategy()
8890
self.assertEqual(strategy._program_config.sync_mode, False)
8991
self.assertEqual(strategy._program_config.runtime_split_send_recv, True)
9092
self.assertEqual(strategy._build_strategy.async_mode, True)
9193

92-
# test set_trainer_runtime_config using TrainerRuntimeConfig
93-
trainer_runtime_config_class = TrainerRuntimeConfig()
94-
trainer_runtime_config_class.send_queue_size = 50
95-
print(trainer_runtime_config_class)
96-
strategy.set_trainer_runtime_config(trainer_runtime_config_class)
9794
trainer_runtime_config = strategy.get_trainer_runtime_config()
98-
self.assertEqual(trainer_runtime_config.send_queue_size, 50)
95+
self.assertEqual(trainer_runtime_config.runtime_configs[
96+
'communicator_send_queue_size'], '100')
9997

10098
# test set_trainer_runtime_config using dict
10199
trainer_runtime_config_dict = dict()
102-
trainer_runtime_config_dict['send_queue_size'] = 100
100+
trainer_runtime_config_dict['communicator_send_queue_size'] = '20'
103101
strategy.set_trainer_runtime_config(trainer_runtime_config_dict)
104102
trainer_runtime_config = strategy.get_trainer_runtime_config()
105103
trainer_communicator_flags = trainer_runtime_config.get_communicator_flags(
106104
)
107105
self.assertIn('communicator_send_queue_size',
108106
trainer_communicator_flags)
109107
self.assertEqual(
110-
trainer_communicator_flags['communicator_send_queue_size'], '100')
108+
trainer_communicator_flags['communicator_send_queue_size'], '20')
111109

112110
# test set_trainer_runtime_config exception
113111
trainer_runtime_config_dict['unknown'] = None

0 commit comments

Comments (0)