@@ -14,7 +14,7 @@
 
 from __future__ import print_function
 import framework
-from framework import Program, default_main_program, Parameter, Variable
+from framework import Program, default_main_program, default_startup_program, Parameter, Variable
 import optimizer
 from layer_helper import LayerHelper
 from distributed_spliter import *
@@ -97,7 +97,7 @@ def transpile(self,
         parameter servers.
 
         :param optimize_ops: op list of optimization, should be the
-            return value of Optimizer.minimize
+                             return value of Optimizer.minimize
         :type optimize_ops: list
         :param params_grads: list of tuple(weight, gradient)
         :type params_grads: list
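For orientation, the two arguments described in this docstring are exactly what Optimizer.minimize returns. Below is a hedged sketch of a trainer-side call, not taken from this diff: the fluid import alias, the DistributeTranspiler class name, the SGDOptimizer choice, and the avg_loss name are assumptions for illustration.

import paddle.v2.fluid as fluid                      # assumed package alias

sgd = fluid.optimizer.SGDOptimizer(learning_rate=0.01)
# optimize_ops / params_grads are the return value of Optimizer.minimize,
# as the docstring above requires.
optimize_ops, params_grads = sgd.minimize(avg_loss)  # avg_loss: a scalar loss variable

t = fluid.DistributeTranspiler()                     # class name assumed
t.transpile(
    optimize_ops,
    params_grads,
    trainer_id=0,                  # index of this trainer, used for the ".trainer_%d" suffix
    pservers="127.0.0.1:6174",     # comma-separated pserver endpoints (split on "," below)
    trainers=2)                    # total trainer count; becomes the "Fanin" attribute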
@@ -131,6 +131,7 @@ def transpile(self,
         # 4. append concat_op to trainer to update local weights.
         # 5. create new program for parameter server.
         # 6. create parameter server program by split_method generated endpoint->VarBlock
+        # 7. update startup_program, rename variables to variables with trainer_id
 
         pserver_endpoints = pservers.split(",")
 
@@ -175,7 +176,6 @@ def transpile(self,
             shape=[0])
 
         # create send_op
-        print("send inputs: ", send_inputs)
         send_op = program.global_block().append_op(
             type="send",
             inputs={"X": send_inputs},
@@ -194,6 +194,15 @@ def transpile(self,
                 outputs={"Out": [orig_param]},
                 attrs={"axis": 0})
 
+        # step 7
+        startup_prog = default_startup_program()
+        for varname in startup_prog.global_block().vars.keys():
+            if varname in param_var_mapping and \
+                    len(param_var_mapping[varname]) == 1:
+                new_var_name = "%s.trainer_%d" % \
+                    (varname, self.trainer_id)
+                startup_prog.global_block().rename_var(varname, new_var_name)
+
     def _create_vars_from_blocklist(self, program, block_list):
         # Create respective variables using the block_list
         block_map = dict()
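The new step 7 keeps the startup program consistent with the trainer program: transpile renames unsplit parameters to "<name>.trainer_<id>", so the startup program that initializes them must use the same names. Parameters that were split into multiple blocks take a different path in _create_vars_from_blocklist and are left alone here. A minimal sketch of the rename_var effect, assuming it runs inside the fluid package so the same framework module is importable; the fc_0.w_0 name and shape are made up for illustration:

from framework import Program

prog = Program()
block = prog.global_block()
block.create_var(name="fc_0.w_0", shape=[784, 10], dtype="float32", persistable=True)

trainer_id = 0
block.rename_var("fc_0.w_0", "fc_0.w_0.trainer_%d" % trainer_id)
# the variable is now only reachable under its per-trainer name
print(block.vars.keys())  # ['fc_0.w_0.trainer_0']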
@@ -210,7 +219,6 @@ def _create_vars_from_blocklist(self, program, block_list):
                 new_var_name = "%s.trainer_%d" % \
                     (orig_var.name, self.trainer_id)
                 program.global_block().rename_var(varname, new_var_name)
-                print("renaming OK...", varname, new_var_name)
                 var_mapping[varname] = \
                     [program.global_block().var(new_var_name)]
                 continue
@@ -377,10 +385,7 @@ def _append_pserver_ops(self, optimize_block, opt_op, endpoint):
         new_inputs = dict()
         # update param/grad shape first, then other inputs like
         # moment can use the updated shape
-        print("mark1")
         for key in opt_op.input_names:
-            # print("opt type: ", opt_op.type)
-            # print("opt op input: ", key)
             if key == "Grad":
                 grad_block = None
                 for g in self.param_grad_ep_mapping[endpoint]["grads"]:
@@ -427,7 +432,6 @@ def _append_pserver_ops(self, optimize_block, opt_op, endpoint):
 
             new_inputs[key] = tmpvar
 
-        print("mark2")
         for key in opt_op.input_names:
             if key in ["Param", "Grad"]:
                 continue
@@ -451,7 +455,6 @@ def _append_pserver_ops(self, optimize_block, opt_op, endpoint):
             inputs=new_inputs,
             outputs=outputs,
             attrs=opt_op.attrs)
-        print("mark3")
 
     def _append_pserver_non_opt_ops(self, optimize_block, opt_op):
         program = optimize_block.program
@@ -505,8 +508,6 @@ def get_pserver_program(self, endpoint):
                 suff_idx = v.name.find(".trainer_")
                 if suff_idx >= 0:
                     orig_var_name = v.name[:suff_idx]
-                print("create variable for program: %s.trainer_%d" %
-                      (orig_var_name, trainer_id))
                 var = pserver_program.global_block().create_var(
                     name="%s.trainer_%d" % (orig_var_name, trainer_id),
                     persistable=True,
@@ -517,11 +518,6 @@ def get_pserver_program(self, endpoint):
         optimize_block = pserver_program.create_block(0)
         # Iterate through the ops and append ops as needed
         for idx, opt_op in enumerate(self.optimize_ops):
-            print("mark0")
-            print(opt_op.inputs.keys())
-            for v in opt_op.inputs.values():
-                print(v.name)
-                print(v.shape)
             is_op_on_pserver = self._is_op_on_pserver(endpoint,
                                                       self.optimize_ops, idx)
             if not is_op_on_pserver:
@@ -547,7 +543,7 @@ def get_pserver_program(self, endpoint):
                 #     p.name
                 #     for p in self.param_grad_ep_mapping[endpoint]["grads"]
                 # ],
-                # "Fanin": self.trainers
+                "Fanin": self.trainers
             })
         pserver_program.sync_with_cpp()
         return pserver_program
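With the commented-out attribute restored, the recv op on the parameter server is told "Fanin", i.e. how many trainers' gradient sends to wait for before running the appended optimize block. A hedged sketch of the parameter-server driver, reusing the transpiler instance t from the earlier sketch; current_endpoint and the Executor usage are illustrative assumptions, not part of this diff.

import paddle.v2.fluid as fluid                      # assumed package alias

current_endpoint = "127.0.0.1:6174"                  # this pserver's own "ip:port"
pserver_prog = t.get_pserver_program(current_endpoint)

exe = fluid.Executor(fluid.CPUPlace())
# The recv op in pserver_prog waits until "Fanin" (== trainers) gradients have
# arrived for its parameters, then runs the optimize block appended above.
exe.run(pserver_prog)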