File tree Expand file tree Collapse file tree 2 files changed +8
-18
lines changed Expand file tree Collapse file tree 2 files changed +8
-18
lines changed Original file line number Diff line number Diff line change @@ -998,18 +998,6 @@ def _run_from_dataset(self,
998
998
999
999
if fetch_handler is not None :
1000
1000
fetch_instance = fetch_handler
1001
- elif fetch_handler is None and fetch_list is not None :
1002
-
1003
- class FH (FetchHandler ):
1004
- def handler (self , fetch_target_vars ):
1005
- for i in range (len (fetch_target_vars )):
1006
- print ("{}: \n {}\n " .format (fetch_info [i ],
1007
- fetch_target_vars [i ]))
1008
-
1009
- fetch_target_names = [var .name for var in fetch_list ]
1010
- fetch_instance = FH (fetch_target_names ,
1011
- period_secs = print_period ,
1012
- return_np = False )
1013
1001
else :
1014
1002
fetch_instance = FetchHandler ([])
1015
1003
@@ -1018,7 +1006,10 @@ def handler(self, fetch_target_vars):
1018
1006
dataset = dataset ,
1019
1007
scope = scope ,
1020
1008
thread = thread ,
1021
- debug = debug )
1009
+ debug = debug ,
1010
+ fetch_list = fetch_list ,
1011
+ fetch_info = fetch_info ,
1012
+ print_period = print_period )
1022
1013
1023
1014
trainer ._set_infer (is_infer )
1024
1015
trainer ._gen_trainer_desc ()
Original file line number Diff line number Diff line change @@ -793,6 +793,8 @@ def transpile(self,
793
793
if self .sync_mode :
794
794
fetch_barrier_input .extend (splited_var )
795
795
796
+ self ._update_remote_sparse_update_op (program , need_sparse_update_params )
797
+
796
798
if self .sync_mode :
797
799
# form a WAW dependency
798
800
program .global_block ().append_op (
@@ -806,11 +808,10 @@ def transpile(self,
806
808
})
807
809
808
810
for param_varname , splited_var in six .iteritems (self .param_var_mapping ):
809
- if len (splited_var ) <= 1 :
810
- continue
811
811
orig_param = program .global_block ().vars [param_varname ]
812
812
if param_varname not in self .sparse_param_to_height_sections :
813
- if not self .config .runtime_split_send_recv :
813
+ if len (splited_var
814
+ ) > 1 and not self .config .runtime_split_send_recv :
814
815
program .global_block ().append_op (
815
816
type = "concat" ,
816
817
inputs = {"X" : splited_var },
@@ -820,8 +821,6 @@ def transpile(self,
820
821
RPC_OP_ROLE_ATTR_NAME : DIST_OP_ROLE_ATTR_VALUE
821
822
})
822
823
823
- self ._update_remote_sparse_update_op (program ,
824
- need_sparse_update_params )
825
824
if not self .sync_mode :
826
825
lr_ops = self ._get_lr_ops ()
827
826
if len (lr_ops ) > 0 :
You can’t perform that action at this time.
0 commit comments