@@ -315,12 +315,22 @@ def transpile(self,
315
315
# step 3.1: insert send op to send gradient vars to parameter servers
316
316
ps_dispatcher .reset ()
317
317
send_vars = []
318
- for varname , splited_vars in grad_var_mapping .items ():
319
- index = find_op_by_output_arg (program .global_block (), varname )
318
+ for orig_varname , splited_vars in grad_var_mapping .items ():
320
319
eplist = ps_dispatcher .dispatch (splited_vars )
321
- if len (splited_vars ) > 1 :
322
- self ._insert_split_op (program , varname , splited_vars )
320
+ if len (splited_vars ) == 1 :
321
+ orig_varname = splited_vars [0 ].name
322
+ index = find_op_by_output_arg (program .global_block (),
323
+ orig_varname )
324
+ elif len (splited_vars ) > 1 :
325
+ orig_var = program .global_block ().vars [orig_varname ]
326
+ index = find_op_by_output_arg (program .global_block (),
327
+ orig_varname )
328
+ self ._insert_split_op (program , orig_var , index , splited_vars )
323
329
index += 1
330
+ else :
331
+ AssertionError ("Can not insert the send op by original "
332
+ "variable name :" , orig_varname )
333
+
324
334
program .global_block ().insert_op (
325
335
index = index + 1 ,
326
336
type = "send_vars" ,
@@ -351,6 +361,12 @@ def transpile(self,
351
361
"RPCClient" : rpc_client_var },
352
362
attrs = {"epmap" : eplist })
353
363
364
+ program .global_block ().append_op (
365
+ type = "fetch_barrier" ,
366
+ inputs = {},
367
+ outputs = {"RPCClient" : rpc_client_var },
368
+ attrs = {"endpoints" : pserver_endpoints })
369
+
354
370
for i , ep in enumerate (eplist ):
355
371
self .param_grad_ep_mapping [ep ]["params" ].append (recv_vars [i ])
356
372
self .param_grad_ep_mapping [ep ]["grads" ].append (send_vars [i ])
@@ -859,9 +875,7 @@ def _clone_var(self, block, var, persistable=True):
859
875
lod_level = var .lod_level ,
860
876
persistable = persistable )
861
877
862
- def _insert_split_op (self , program , orig_varname , splited_vars ):
863
- orig_var = program .global_block ().vars [orig_varname ]
864
- index = find_op_by_output_arg (program .global_block (), orig_varname )
878
+ def _insert_split_op (self , program , orig_var , index , splited_vars ):
865
879
if orig_var .type == core .VarDesc .VarType .SELECTED_ROWS :
866
880
height_sections = []
867
881
for v in splited_vars :
@@ -887,45 +901,6 @@ def _insert_split_op(self, program, orig_varname, splited_vars):
887
901
AssertionError ("Variable type should be in set "
888
902
"[LOD_TENSOR, SELECTED_ROWS]" )
889
903
890
- def _append_split_op (self , program , gradblocks ):
891
- # Split variables that need to be split and append respective ops
892
- add_suffix = False
893
- if self .trainer_num > 1 :
894
- add_suffix = True
895
- var_mapping = self ._create_vars_from_blocklist (
896
- program , gradblocks , add_trainer_suffix = add_suffix )
897
- for varname , splited_vars in var_mapping .iteritems ():
898
- # variable that don't need to split have empty splited_vars
899
- if len (splited_vars ) <= 1 :
900
- continue
901
- orig_var = program .global_block ().vars [varname ]
902
- index = find_op_by_output_arg (program .global_block (), orig_var .name )
903
- if orig_var .type == core .VarDesc .VarType .SELECTED_ROWS :
904
- height_sections = []
905
- for v in splited_vars :
906
- height_sections .append (v .shape [0 ])
907
- program .global_block ().insert_op (
908
- index = index + 1 ,
909
- type = "split_selected_rows" ,
910
- inputs = {"X" : orig_var },
911
- outputs = {"Out" : splited_vars },
912
- attrs = {"height_sections" : height_sections })
913
- elif orig_var .type == core .VarDesc .VarType .LOD_TENSOR :
914
- sections = []
915
- for v in splited_vars :
916
- sections .append (v .shape [0 ])
917
- program .global_block ().insert_op (
918
- index = index + 1 ,
919
- type = "split_byref" ,
920
- inputs = {"X" : orig_var },
921
- outputs = {"Out" : splited_vars },
922
- attrs = {"sections" : sections } # assume split evenly
923
- )
924
- else :
925
- AssertionError ("Variable type should be in set "
926
- "[LOD_TENSOR, SELECTED_ROWS]" )
927
- return var_mapping
928
-
929
904
def _get_optimizer_input_shape (self , op_type , varkey , orig_shape ,
930
905
param_shape ):
931
906
"""
0 commit comments