24
24
25
25
class TrainerRuntimeConfig (object ):
26
26
def __init__ (self ):
27
- self .max_merge_var_num = os .getenv (
28
- "FLAGS_communicator_max_merge_var_num" , "20" )
29
- self .send_queue_size = os .getenv ("FLAGS_communicator_send_queue_size" ,
30
- "20" )
31
- self .independent_recv_thread = os .getenv (
32
- "FLAGS_communicator_independent_recv_thread" , "1" )
33
- self .min_send_grad_num_before_recv = os .getenv (
34
- "FLAGS_communicator_min_send_grad_num_before_recv" , "20" )
35
- self .thread_pool_size = os .getenv ("FLAGS_communicator_thread_pool_size" ,
36
- "5" )
37
- self .send_wait_times = os .getenv ("FLAGS_communicator_send_wait_times" ,
38
- "5" )
39
- self .fake_rpc = os .getenv ("FLAGS_communicator_fake_rpc" , "0" )
40
- self .merge_sparse_grad = os .getenv (
41
- "FLAGS_communicator_merge_sparse_grad" , "1" )
42
- self .is_sgd_optimizer = os .getenv ("FLAGS_communicator_is_sgd_optimizer" ,
43
- "1" )
44
-
27
+ self .runtime_configs = {}
45
28
# not used
46
- self ._rpc_deadline = os .getenv ("FLAGS_rpc_deadline" , "180000" )
47
- self ._rpc_retry_times = os .getenv ("FLAGS_rpc_retry_times" , "3" )
29
+ self .runtime_configs ['rpc_deadline' ] = os .getenv ("FLAGS_rpc_deadline" ,
30
+ "180000" )
31
+ self .runtime_configs ['rpc_retry_times' ] = os .getenv (
32
+ "FLAGS_rpc_retry_times" , "3" )
48
33
49
34
def get_communicator_flags (self ):
50
- _communicator_flags = dict ()
51
- _communicator_flags ["communicator_max_merge_var_num" ] = str (
52
- self .max_merge_var_num )
53
- _communicator_flags ["communicator_send_queue_size" ] = str (
54
- self .send_queue_size )
55
- _communicator_flags ["communicator_independent_recv_thread" ] = str (
56
- self .independent_recv_thread )
57
- _communicator_flags ["communicator_min_send_grad_num_before_recv" ] = str (
58
- self .min_send_grad_num_before_recv )
59
- _communicator_flags ["communicator_thread_pool_size" ] = str (
60
- self .thread_pool_size )
61
- _communicator_flags ["communicator_send_wait_times" ] = str (
62
- self .send_wait_times )
63
- _communicator_flags ["communicator_is_sgd_optimizer" ] = str (
64
- self .is_sgd_optimizer )
65
- return _communicator_flags
35
+ return self .runtime_configs
66
36
67
37
def __repr__ (self ):
68
- _str = "please check that TrainerRuntimeConfig is as expected:\n "
69
- _communicator_flags = self .get_communicator_flags ()
70
- for key in _communicator_flags :
71
- _str += "{}: {}\n " .format (key , _communicator_flags [key ])
38
+ raw0 , raw1 , length = 45 , 5 , 50
39
+ h_format = "{:^45s}{:<5s}\n "
40
+ l_format = "{:<45s}{:<5s}\n "
41
+
42
+ border = "" .join (["=" ] * length )
43
+ line = "" .join (["-" ] * length )
44
+
45
+ draws = ""
46
+ draws += border + "\n "
47
+ draws += h_format .format ("TrainerRuntimeConfig Overview" , "Value" )
48
+ draws += line + "\n "
49
+
50
+ for k , v in self .get_communicator_flags ().items ():
51
+ draws += l_format .format (k , v )
52
+
53
+ draws += border
54
+
55
+ _str = "\n {}\n " .format (draws )
72
56
return _str
73
57
74
58
@@ -77,9 +61,11 @@ def __init__(self):
77
61
self ._program_config = DistributeTranspilerConfig ()
78
62
self ._trainer_runtime_config = TrainerRuntimeConfig ()
79
63
self ._server_runtime_config = ServerRuntimeConfig ()
64
+ num_threads = int (os .getenv ("CPU_NUM" , "1" ))
65
+
80
66
self ._execute_strategy = fluid .ExecutionStrategy ()
81
67
self ._build_strategy = fluid .BuildStrategy ()
82
- num_threads = int ( os . getenv ( "CPU_NUM" , "1" ))
68
+
83
69
self ._execute_strategy .num_threads = num_threads
84
70
if num_threads > 1 :
85
71
self ._build_strategy .reduce_strategy = fluid .BuildStrategy .ReduceStrategy .Reduce
@@ -110,9 +96,9 @@ def set_trainer_runtime_config(self, config):
110
96
if isinstance (config , TrainerRuntimeConfig ):
111
97
self ._trainer_runtime_config = config
112
98
elif isinstance (config , dict ):
113
- for key in config :
114
- if hasattr ( self ._trainer_runtime_config , key ) :
115
- setattr ( self ._trainer_runtime_config , key , config [key ])
99
+ for key , Value in config . items () :
100
+ if key in self ._trainer_runtime_config . runtime_configs :
101
+ self ._trainer_runtime_config . runtime_configs [key ] = Value
116
102
else :
117
103
raise ValueError (
118
104
"TrainerRuntimeConfig doesn't have key: {}" .format (key ))
@@ -182,6 +168,21 @@ def __init__(self):
182
168
self ._program_config .runtime_split_send_recv = False
183
169
self ._build_strategy .async_mode = False
184
170
171
+ num_threads = os .getenv ("CPU_NUM" , "1" )
172
+
173
+ self ._trainer_runtime_config .runtime_configs [
174
+ 'communicator_max_merge_var_num' ] = os .getenv (
175
+ "FLAGS_communicator_max_merge_var_num" , num_threads )
176
+ self ._trainer_runtime_config .runtime_configs [
177
+ 'communicator_send_wait_times' ] = os .getenv (
178
+ "FLAGS_communicator_send_wait_times" , "5" )
179
+ self ._trainer_runtime_config .runtime_configs [
180
+ 'communicator_thread_pool_size' ] = os .getenv (
181
+ "FLAGS_communicator_thread_pool_size" , "10" )
182
+ self ._trainer_runtime_config .runtime_configs [
183
+ 'communicator_send_queue_size' ] = os .getenv (
184
+ "FLAGS_communicator_send_queue_size" , num_threads )
185
+
185
186
186
187
class AsyncStrategy (DistributedStrategy ):
187
188
def __init__ (self ):
@@ -190,6 +191,30 @@ def __init__(self):
190
191
self ._program_config .runtime_split_send_recv = True
191
192
self ._build_strategy .async_mode = True
192
193
194
+ num_threads = os .getenv ("CPU_NUM" , "1" )
195
+
196
+ self ._trainer_runtime_config .runtime_configs [
197
+ 'communicator_max_merge_var_num' ] = os .getenv (
198
+ "FLAGS_communicator_max_merge_var_num" , num_threads )
199
+ self ._trainer_runtime_config .runtime_configs [
200
+ 'communicator_independent_recv_thread' ] = os .getenv (
201
+ "FLAGS_communicator_independent_recv_thread" , "0" )
202
+ self ._trainer_runtime_config .runtime_configs [
203
+ 'communicator_min_send_grad_num_before_recv' ] = os .getenv (
204
+ "FLAGS_communicator_min_send_grad_num_before_recv" , num_threads )
205
+ self ._trainer_runtime_config .runtime_configs [
206
+ 'communicator_thread_pool_size' ] = os .getenv (
207
+ "FLAGS_communicator_thread_pool_size" , "10" )
208
+ self ._trainer_runtime_config .runtime_configs [
209
+ 'communicator_send_wait_times' ] = os .getenv (
210
+ "FLAGS_communicator_send_wait_times" , "5" )
211
+ self ._trainer_runtime_config .runtime_configs [
212
+ 'communicator_is_sgd_optimizer' ] = os .getenv (
213
+ "FLAGS_communicator_is_sgd_optimizer" , "1" )
214
+ self ._trainer_runtime_config .runtime_configs [
215
+ 'communicator_send_queue_size' ] = os .getenv (
216
+ "FLAGS_communicator_send_queue_size" , num_threads )
217
+
193
218
194
219
class HalfAsyncStrategy (DistributedStrategy ):
195
220
def __init__ (self ):
@@ -200,15 +225,37 @@ def __init__(self):
200
225
self ._build_strategy .async_mode = True
201
226
self ._execute_strategy .use_thread_barrier = True
202
227
228
+ num_threads = os .getenv ("CPU_NUM" , "1" )
229
+
230
+ self ._trainer_runtime_config .runtime_configs [
231
+ 'communicator_max_merge_var_num' ] = os .getenv (
232
+ "FLAGS_communicator_max_merge_var_num" , num_threads )
233
+ self ._trainer_runtime_config .runtime_configs [
234
+ 'communicator_send_wait_times' ] = os .getenv (
235
+ "FLAGS_communicator_send_wait_times" , "5" )
236
+ self ._trainer_runtime_config .runtime_configs [
237
+ 'communicator_thread_pool_size' ] = os .getenv (
238
+ "FLAGS_communicator_thread_pool_size" , "10" )
239
+ self ._trainer_runtime_config .runtime_configs [
240
+ 'communicator_send_queue_size' ] = os .getenv (
241
+ "FLAGS_communicator_send_queue_size" , num_threads )
242
+
203
243
204
244
class GeoStrategy (DistributedStrategy ):
205
245
def __init__ (self , update_frequency = 100 ):
206
246
super (GeoStrategy , self ).__init__ ()
207
247
self ._program_config .sync_mode = False
208
248
self ._program_config .runtime_split_send_recv = True
209
- self ._build_strategy .async_mode = True
210
249
self ._program_config .geo_sgd_mode = True
211
250
self ._program_config .geo_sgd_need_push_nums = update_frequency
251
+ self ._build_strategy .async_mode = True
252
+
253
+ self ._trainer_runtime_config .runtime_configs [
254
+ 'communicator_thread_pool_size' ] = os .getenv (
255
+ "FLAGS_communicator_thread_pool_size" , "10" )
256
+ self ._trainer_runtime_config .runtime_configs [
257
+ 'communicator_send_wait_times' ] = os .getenv (
258
+ "FLAGS_communicator_send_wait_times" , "5" )
212
259
213
260
214
261
class StrategyFactory (object ):
0 commit comments