19
19
20
20
import os
21
21
import paddle .fluid as fluid
22
- from paddle .fluid .transpiler .distribute_transpiler import DistributeTranspilerConfig , ServerRuntimeConfig
22
+ from paddle .fluid .transpiler .distribute_transpiler import DistributeTranspilerConfig , ServerRuntimeConfig , DistributedMode
23
23
24
24
25
25
class TrainerRuntimeConfig (object ):
26
26
def __init__ (self ):
27
+ self .mode = None
28
+ num_threads = os .getenv ("CPU_NUM" , "1" )
29
+
27
30
self .runtime_configs = {}
31
+ self .runtime_configs ['communicator_max_merge_var_num' ] = os .getenv (
32
+ "FLAGS_communicator_max_merge_var_num" , num_threads )
33
+ self .runtime_configs ['communicator_send_queue_size' ] = os .getenv (
34
+ "FLAGS_communicator_send_queue_size" , num_threads )
35
+ self .runtime_configs [
36
+ 'communicator_independent_recv_thread' ] = os .getenv (
37
+ "FLAGS_communicator_independent_recv_thread" , "1" )
38
+ self .runtime_configs [
39
+ 'communicator_min_send_grad_num_before_recv' ] = os .getenv (
40
+ "FLAGS_communicator_min_send_grad_num_before_recv" , num_threads )
41
+ self .runtime_configs ['communicator_thread_pool_size' ] = os .getenv (
42
+ "FLAGS_communicator_thread_pool_size" , "5" )
43
+ self .runtime_configs ['communicator_send_wait_times' ] = os .getenv (
44
+ "FLAGS_communicator_send_wait_times" , "5" )
45
+ self .runtime_configs ['communicator_is_sgd_optimizer' ] = os .getenv (
46
+ "FLAGS_communicator_is_sgd_optimizer" , "1" )
47
+
28
48
# not used
29
49
self .runtime_configs ['rpc_deadline' ] = os .getenv ("FLAGS_rpc_deadline" ,
30
50
"180000" )
31
51
self .runtime_configs ['rpc_retry_times' ] = os .getenv (
32
52
"FLAGS_rpc_retry_times" , "3" )
33
53
34
54
def get_communicator_flags (self ):
35
- return self .runtime_configs
36
-
37
- def __repr__ (self ):
55
+ need_keys = []
56
+ num_threads = os .getenv ("CPU_NUM" , "1" )
57
+ mode_str = ""
58
+ if self .mode is None or self .mode == DistributedMode .ASYNC :
59
+ need_keys = self .runtime_configs .keys ()
60
+ mode_str = "async"
61
+ elif self .mode == DistributedMode .SYNC or self .mode == DistributedMode .HALF_ASYNC :
62
+ mode_str = "sync or half_async"
63
+ need_keys = [
64
+ 'communicator_max_merge_var_num' ,
65
+ 'communicator_send_wait_times' , 'communicator_thread_pool_size' ,
66
+ 'communicator_send_queue_size'
67
+ ]
68
+ elif self .mode == DistributedMode .GEO :
69
+ mode_str = "GEO"
70
+ need_keys = [
71
+ 'communicator_thread_pool_size' , 'communicator_send_wait_times'
72
+ ]
73
+ else :
74
+ raise ValueError ("Unsupported Mode" )
75
+
76
+ if self .mode == DistributedMode .SYNC or self .mode == DistributedMode .HALF_ASYNC :
77
+ max_merge_var_num = self .runtime_configs [
78
+ 'communicator_max_merge_var_num' ]
79
+ send_queue_size = self .runtime_configs [
80
+ 'communicator_send_queue_size' ]
81
+ if max_merge_var_num != num_threads :
82
+ print ('WARNING: In {} mode, communicator_max_merge_var_num '
83
+ 'must be equal to CPU_NUM. But received, '
84
+ 'communicator_max_merge_var_num = {}, CPU_NUM = '
85
+ '{}. communicator_max_merge_var_num will be fored to {}.'
86
+ .format (mode_str , max_merge_var_num , num_threads ,
87
+ num_threads ))
88
+ self .runtime_configs [
89
+ 'communicator_max_merge_var_num' ] = num_threads
90
+ if send_queue_size != num_threads :
91
+ print ('WARNING: In {} mode, communicator_send_queue_size '
92
+ 'must be equal to CPU_NUM. But received, '
93
+ 'communicator_send_queue_size = {}, CPU_NUM = '
94
+ '{}. communicator_send_queue_size will be fored to {}.'
95
+ .format (mode_str , send_queue_size , num_threads ,
96
+ num_threads ))
97
+ self .runtime_configs [
98
+ 'communicator_send_queue_size' ] = num_threads
99
+
100
+ return dict ((key , str (self .runtime_configs [key ])) for key in need_keys )
101
+
102
+ def display (self , configs ):
38
103
raw0 , raw1 , length = 45 , 5 , 50
39
104
h_format = "{:^45s}{:<5s}\n "
40
105
l_format = "{:<45s}{:<5s}\n "
@@ -47,14 +112,17 @@ def __repr__(self):
47
112
draws += h_format .format ("TrainerRuntimeConfig Overview" , "Value" )
48
113
draws += line + "\n "
49
114
50
- for k , v in self . get_communicator_flags () .items ():
115
+ for k , v in configs .items ():
51
116
draws += l_format .format (k , v )
52
117
53
118
draws += border
54
119
55
120
_str = "\n {}\n " .format (draws )
56
121
return _str
57
122
123
+ def __repr__ (self ):
124
+ return self .display (self .get_communicator_flags ())
125
+
58
126
59
127
class DistributedStrategy (object ):
60
128
def __init__ (self ):
@@ -105,6 +173,12 @@ def set_program_config(self, config):
105
173
raise TypeError (
106
174
"program_config only accept input type: dict or DistributeTranspilerConfig"
107
175
)
176
+ self .check_program_config ()
177
+
178
+ def check_program_config (self ):
179
+ raise NotImplementedError (
180
+ "check_program_config must be implemented by derived class. You should use StrategyFactory to create DistributedStrategy."
181
+ )
108
182
109
183
def get_trainer_runtime_config (self ):
110
184
return self ._trainer_runtime_config
@@ -123,6 +197,12 @@ def set_trainer_runtime_config(self, config):
123
197
raise TypeError (
124
198
"trainer_runtime_config only accept input type: dict or TrainerRuntimeConfig"
125
199
)
200
+ self .check_trainer_runtime_config ()
201
+
202
+ def check_trainer_runtime_config (self ):
203
+ raise NotImplementedError (
204
+ "check_trainer_runtime_config must be implemented by derived class. You should use StrategyFactory to create DistributedStrategy."
205
+ )
126
206
127
207
def get_server_runtime_config (self ):
128
208
return self ._server_runtime_config
@@ -141,6 +221,12 @@ def set_server_runtime_config(self, config):
141
221
raise TypeError (
142
222
"server_runtime_config only accept input type: dict or ServerRuntimeConfig"
143
223
)
224
+ self .check_server_runtime_config ()
225
+
226
+ def check_server_runtime_config (self ):
227
+ raise NotImplementedError (
228
+ "check_server_runtime_config must be implemented by derived class. You should use StrategyFactory to create DistributedStrategy."
229
+ )
144
230
145
231
def get_execute_strategy (self ):
146
232
return self ._execute_strategy
@@ -159,6 +245,12 @@ def set_execute_strategy(self, config):
159
245
raise TypeError (
160
246
"execute_strategy only accept input type: dict or ExecutionStrategy"
161
247
)
248
+ self .check_execute_strategy ()
249
+
250
+ def check_execute_strategy (self ):
251
+ raise NotImplementedError (
252
+ "check_execute_strategy must be implemented by derived class. You should use StrategyFactory to create DistributedStrategy."
253
+ )
162
254
163
255
def get_build_strategy (self ):
164
256
return self ._build_strategy
@@ -176,106 +268,121 @@ def set_build_strategy(self, config):
176
268
else :
177
269
raise TypeError (
178
270
"build_strategy only accept input type: dict or BuildStrategy" )
271
+ self .check_build_strategy ()
272
+
273
+ def check_build_strategy (self ):
274
+ raise NotImplementedError (
275
+ "check_build_strategy must be implemented by derived class. You should use StrategyFactory to create DistributedStrategy."
276
+ )
179
277
180
278
181
279
class SyncStrategy(DistributedStrategy):
    """Fully synchronous distributed training strategy.

    Uses the half-async runtime path plus a thread barrier so every trainer
    thread steps in lock-step.
    """

    def __init__(self):
        super(SyncStrategy, self).__init__()
        # Apply each sub-config check, in this fixed order, so the strategy
        # is fully configured right after construction.
        for apply_check in (self.check_program_config,
                            self.check_trainer_runtime_config,
                            self.check_server_runtime_config,
                            self.check_build_strategy,
                            self.check_execute_strategy):
            apply_check()

    def check_trainer_runtime_config(self):
        # Communicator flag filtering keys off SYNC mode.
        self._trainer_runtime_config.mode = DistributedMode.SYNC

    def check_program_config(self):
        # sync_mode=False with half_async + completely_not_async steers the
        # transpiler down the half-async code path while keeping fully
        # synchronous semantics.
        self._program_config.sync_mode = False
        self._program_config.runtime_split_send_recv = True
        self._program_config.half_async = True
        self._program_config.completely_not_async = True

    def check_server_runtime_config(self):
        # Server-side defaults need no adjustment for sync training.
        pass

    def check_execute_strategy(self):
        # Barrier forces all trainer threads to advance together.
        self._execute_strategy.use_thread_barrier = True

    def check_build_strategy(self):
        self._build_strategy.async_mode = True
205
305
206
306
207
307
class AsyncStrategy(DistributedStrategy):
    """Fully asynchronous distributed training strategy."""

    def __init__(self):
        super(AsyncStrategy, self).__init__()
        # Apply each sub-config check, in this fixed order, so the strategy
        # is fully configured right after construction.
        for apply_check in (self.check_program_config,
                            self.check_trainer_runtime_config,
                            self.check_server_runtime_config,
                            self.check_build_strategy,
                            self.check_execute_strategy):
            apply_check()

    def check_trainer_runtime_config(self):
        # Communicator flag filtering keys off ASYNC mode.
        self._trainer_runtime_config.mode = DistributedMode.ASYNC

    def check_program_config(self):
        self._program_config.sync_mode = False
        self._program_config.runtime_split_send_recv = True

    def check_server_runtime_config(self):
        # Server-side defaults need no adjustment for async training.
        pass

    def check_execute_strategy(self):
        # No barrier: trainer threads run free in async mode.
        pass

    def check_build_strategy(self):
        self._build_strategy.async_mode = True
237
331
238
332
239
333
class HalfAsyncStrategy(DistributedStrategy):
    """Half-asynchronous distributed training strategy.

    Asynchronous communication within a step, with a thread barrier between
    steps.
    """

    def __init__(self):
        super(HalfAsyncStrategy, self).__init__()
        # Apply each sub-config check, in this fixed order, so the strategy
        # is fully configured right after construction.
        for apply_check in (self.check_program_config,
                            self.check_trainer_runtime_config,
                            self.check_server_runtime_config,
                            self.check_build_strategy,
                            self.check_execute_strategy):
            apply_check()

    def check_trainer_runtime_config(self):
        # Communicator flag filtering keys off HALF_ASYNC mode.
        self._trainer_runtime_config.mode = DistributedMode.HALF_ASYNC

    def check_program_config(self):
        self._program_config.sync_mode = False
        self._program_config.runtime_split_send_recv = True
        self._program_config.half_async = True

    def check_server_runtime_config(self):
        # Server-side defaults need no adjustment for half-async training.
        pass

    def check_execute_strategy(self):
        # Barrier synchronizes trainer threads at step boundaries.
        self._execute_strategy.use_thread_barrier = True

    def check_build_strategy(self):
        self._build_strategy.async_mode = True
262
358
263
359
264
360
class GeoStrategy(DistributedStrategy):
    """GEO-SGD distributed training strategy.

    Args:
        update_frequency (int): number of local update steps accumulated
            before parameters are pushed to the servers. Default 100.
    """

    def __init__(self, update_frequency=100):
        super(GeoStrategy, self).__init__()
        # Record the push interval before the config checks run.
        self._program_config.geo_sgd_need_push_nums = update_frequency
        # Apply each sub-config check, in this fixed order, so the strategy
        # is fully configured right after construction.
        for apply_check in (self.check_program_config,
                            self.check_trainer_runtime_config,
                            self.check_server_runtime_config,
                            self.check_build_strategy,
                            self.check_execute_strategy):
            apply_check()

    def check_program_config(self):
        self._program_config.sync_mode = False
        self._program_config.runtime_split_send_recv = True
        self._program_config.geo_sgd_mode = True

    def check_trainer_runtime_config(self):
        # Communicator flag filtering keys off GEO mode.
        self._trainer_runtime_config.mode = DistributedMode.GEO

    def check_server_runtime_config(self):
        # Server-side defaults need no adjustment for GEO training.
        pass

    def check_execute_strategy(self):
        # No barrier needed: GEO tolerates divergent local steps.
        pass

    def check_build_strategy(self):
        self._build_strategy.async_mode = True
279
386
280
387
281
388
class StrategyFactory (object ):
0 commit comments