@@ -121,6 +121,7 @@ def split_dense_variable(var_list,
                 block_size += dim1 - remains
         # update split_count after aligning
         split_count = int(math.ceil(var_numel / float(block_size)))
+        print("###split var ", var.name, var.shape, block_size, split_count)
         for block_id in xrange(split_count):
             curr_block_size = min(block_size, var_numel - (
                 (block_id) * block_size))
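Note: the splitting arithmetic in this hunk can be checked in isolation. The sketch below is only an illustration (the 1000 x 7 shape and the pserver count are made up, not taken from the patch); it shows how block_size gets padded up to a multiple of the row width dim1 and how split_count is then recomputed from the aligned block size.

import math

# hypothetical dense variable of shape (1000, 7) split across 3 pservers
var_numel = 1000 * 7                                         # 7000 elements
split_count = 3
block_size = int(math.ceil(var_numel / float(split_count)))  # 2334
dim1 = 7                                                     # product of trailing dims
remains = block_size % dim1                                  # 3
if remains != 0:
    block_size += dim1 - remains                             # align to whole rows -> 2338
split_count = int(math.ceil(var_numel / float(block_size)))  # still 3 blocks
print(block_size, split_count)                               # 2338 3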
@@ -191,7 +192,6 @@ def transpile(self,
         for b in param_blocks:
             varname, block_id, _ = b.split(":")
             send_outputs.append(param_var_mapping[varname][int(block_id)])
-
         # let send_op know which endpoint to send which var to, eplist has the same
         # order as send_inputs.
         eplist = split_method(send_inputs, pserver_endpoints)
@@ -230,21 +230,6 @@ def transpile(self,
                 outputs={"Out": [orig_param]},
                 attrs={"axis": 0})
 
-        self.lr_param_mapping = self._create_lr_param_mapping()
-
-    def _create_lr_param_mapping(self):
-        lr_mapping = dict()
-        for _, opt_op in enumerate(self.optimize_ops):
-            if not opt_op.inputs or not opt_op.inputs.has_key("LearningRate") \
-                    or not opt_op.inputs.has_key("Param"):
-                continue
-            lr = opt_op.inputs["LearningRate"].name
-            param = opt_op.inputs["Param"].name
-            if not lr_mapping.has_key(lr):
-                lr_mapping.update({lr: list()})
-            lr_mapping[lr].append(param)
-        return lr_mapping
-
     def _create_vars_from_blocklist(self, program, block_list):
         # Create respective variables using the block_list
         block_map = dict()
@@ -271,13 +256,15 @@ def _create_vars_from_blocklist(self, program, block_list):
                 splited_shape = [rows]
                 if len(orig_shape) >= 2:
                     splited_shape.extend(orig_shape[1:])
+                print("###splited: ", size, rows, splited_shape)
                 var = program.global_block().create_var(
                     name="%s.block%d" % (varname, i),
                     psersistable=False,
                     dtype=orig_var.dtype,
                     type=orig_var.type,
                     shape=splited_shape)  # flattend splited var
                 var_mapping[varname].append(var)
+                print("###created split var ", var)
         return var_mapping
 
     def _clone_var(self, block, var):
@@ -369,18 +356,9 @@ def _get_optimizer_input_shape(self, op_type, varkey, orig_shape,
             pass
         return orig_shape
 
-    def _fetch_var_names(self, param_dict):
-        res = []
-        if not param_dict:
-            return res
-        for _, values in param_dict.iteritems():
-            if not isinstance(values, list):
-                values = [values]
-            res += [v.name for v in values]
-        return res
-
     def _append_pserver_ops(self, optimize_block, opt_op, endpoint):
         program = optimize_block.program
+        pserver_block = program.global_block()
         new_inputs = dict()
         # update param/grad shape first, then other inputs like
         # moment can use the updated shape
@@ -395,11 +373,11 @@ def _append_pserver_ops(self, optimize_block, opt_op, endpoint):
                     # do not append this op if current endpoint
                    # is not dealing with this grad block
                     return
-                merged_var = program.global_block().vars[grad_block.name]
+                merged_var = pserver_block.vars[grad_block.name]
                 # append merging ops if trainers > 1
                 if self.trainers > 1:
                     vars2merge = self._create_var_for_trainers(
-                        program.global_block(), grad_block, self.trainers)
+                        pserver_block, grad_block, self.trainers)
                     optimize_block.append_op(
                         type="sum",
                         inputs={"X": vars2merge},
@@ -419,41 +397,42 @@ def _append_pserver_ops(self, optimize_block, opt_op, endpoint):
                         break
                 if not param_block:
                     return
-                tmpvar = program.global_block().create_var(
+                tmpvar = pserver_block.create_var(
                     name=param_block.name,
                     persistable=True,
                     dtype=param_block.dtype,
                     shape=param_block.shape)
-
                 new_inputs[key] = tmpvar
             elif key == "LearningRate":
                 # leraning rate variable has already be created by non-optimize op,
                 # don't create it once again.
-                new_inputs[key] = program.global_block().vars[opt_op.input(key)[
-                    0]]
+                new_inputs[key] = pserver_block.vars[opt_op.input(key)[0]]
 
         for key in opt_op.input_names:
             new_shape = None
             if key in ["Param", "Grad", "LearningRate"]:
                 continue
-            var = program.global_block().vars[opt_op.input(key)[0]]
+            var = self.program.global_block().vars[opt_op.input(key)[0]]
             # update accumulator variable shape
             param_shape = new_inputs["Param"].shape
             new_shape = self._get_optimizer_input_shape(opt_op.type, key,
                                                         var.shape, param_shape)
-            tmpvar = program.global_block().create_var(
+            tmpvar = pserver_block.create_var(
                 name=var.name,
                 persistable=var.persistable,
                 dtype=var.dtype,
                 shape=new_shape)
             new_inputs[key] = tmpvar
 
         # change output's ParamOut variable
-        opt_op.outputs["ParamOut"] = new_inputs["Param"]
+        outputs = self._get_output_map_from_op(self.program.global_block().vars,
+                                               opt_op)
+        outputs["ParamOut"] = new_inputs["Param"]
+
         optimize_block.append_op(
             type=opt_op.type,
             inputs=new_inputs,
-            outputs=opt_op.outputs,
+            outputs=outputs,
             attrs=opt_op.attrs)
 
     def _append_pserver_non_opt_ops(self, optimize_block, opt_op):
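Note: the hunk above stops mutating opt_op.outputs and instead rebuilds a fresh output map before appending the op. The patch calls a _get_output_map_from_op helper whose body is not part of this diff; the snippet below is only a guess at its shape, inferred from how it is invoked (a var-name dict plus an op), not the actual implementation.

def _get_output_map_from_op(varmap, op):
    # assumed behaviour: map each output slot of `op` to the Variable
    # objects found in `varmap` (e.g. program.global_block().vars)
    iomap = dict()
    for key in op.output_names:
        vars = [varmap[name] for name in op.output(key)]
        iomap[key] = vars[0] if len(vars) == 1 else vars
    return iomap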
@@ -497,11 +476,12 @@ def _is_op_connected(self, op1, op2):
         # If one op's input is another op's output or
         # one op's output is another op's input, we say
         # the two operator is connected.
-        op1_input_names = self._fetch_var_names(op1.inputs)
-        op1_output_names = self._fetch_var_names(op1.outputs)
+        op1_input_names = op1.desc.input_arg_names()
+        op1_output_names = op1.desc.output_arg_names()
+
+        op2_input_names = op2.desc.input_arg_names()
+        op2_output_names = op2.desc.output_arg_names()
 
-        op2_input_names = self._fetch_var_names(op2.inputs)
-        op2_output_names = self._fetch_var_names(op2.outputs)
         if set(op1_output_names) & set(op2_input_names) or \
            set(op1_input_names) & set(op2_output_names):
             return True
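Note: with _fetch_var_names gone, _is_op_connected reads argument names directly from each op's desc; the connectivity test itself is just a set intersection. A toy example (all variable names invented):

op1_output_names = ["fc_0.tmp_0"]          # op1 produces this var
op2_input_names = ["fc_0.tmp_0", "label"]  # op2 consumes it
op1_input_names = ["x", "fc_0.w"]
op2_output_names = ["cost"]

connected = bool(set(op1_output_names) & set(op2_input_names) or
                 set(op1_input_names) & set(op2_output_names))
print(connected)  # True, because op1 feeds op2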
@@ -521,21 +501,21 @@ def _create_ufind(self, optimize_ops):
     def _is_opt_op(self, op):
         # NOTE: It's a HACK implement.
         # optimize op: SGDOptimize, MomentumOptimizer, AdamOptimizer and etc...
-        if op.inputs and op.inputs.has_key("Param") \
-                and op.inputs.has_key("LearningRate"):
+        if "Param" in op.input_names and \
+                "LearningRate" in op.input_names:
             return True
         return False
 
     def _is_opt_op_on_pserver(self, endpoint, op):
         param_names = [
             p.name for p in self.param_grad_ep_mapping[endpoint]["params"]
         ]
-        if op.inputs["Param"].name in param_names:
+        if op.input("Param") in param_names:
             return True
         else:
             for n in param_names:
-                param = op.inputs["Param"].name
-                if same_or_split_var(n, param) and n != op.inputs["Param"].name:
+                param = op.input("Param")[0]
+                if same_or_split_var(n, param) and n != param:
                     return True
             return False
         return False
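Note: after this change, optimizer ops are recognized purely from their input slot names rather than the removed op.inputs dict. A minimal illustration of that check (the slot list is what an SGD-style op would typically expose, shown here as an assumption):

input_names = ["Param", "Grad", "LearningRate"]   # hypothetical SGD op slots
is_opt_op = "Param" in input_names and "LearningRate" in input_names
print(is_opt_op)  # True; a plain "mul" op with slots ["X", "Y"] would fail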
@@ -551,6 +531,8 @@ def get_pserver_program(self, endpoint):
         """
         # step5
         pserver_program = Program()
+        print("param mapping on pserver: #### ",
+              self.param_grad_ep_mapping[endpoint]["params"])
         for v in self.param_grad_ep_mapping[endpoint]["params"]:
             self._clone_var(pserver_program.global_block(), v)
         for v in self.param_grad_ep_mapping[endpoint]["grads"]:
@@ -564,7 +546,6 @@ def get_pserver_program(self, endpoint):
                 persistable=True,
                 dtype=v.dtype,
                 shape=v.shape)
-
         # step6
         optimize_block = pserver_program.create_block(0)
         # step 6.1