Skip to content

Commit 8c1e1de

Browse files
authored
fix fetch handler error with pslib (#20681)
* fix fetch handler error with pslib * fix distributed lookup table op with 1 pserver
1 parent eeaf04d commit 8c1e1de

File tree

2 files changed

+8
-18
lines changed

2 files changed

+8
-18
lines changed

python/paddle/fluid/executor.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -998,18 +998,6 @@ def _run_from_dataset(self,
998998

999999
if fetch_handler is not None:
10001000
fetch_instance = fetch_handler
1001-
elif fetch_handler is None and fetch_list is not None:
1002-
1003-
class FH(FetchHandler):
1004-
def handler(self, fetch_target_vars):
1005-
for i in range(len(fetch_target_vars)):
1006-
print("{}: \n {}\n".format(fetch_info[i],
1007-
fetch_target_vars[i]))
1008-
1009-
fetch_target_names = [var.name for var in fetch_list]
1010-
fetch_instance = FH(fetch_target_names,
1011-
period_secs=print_period,
1012-
return_np=False)
10131001
else:
10141002
fetch_instance = FetchHandler([])
10151003

@@ -1018,7 +1006,10 @@ def handler(self, fetch_target_vars):
10181006
dataset=dataset,
10191007
scope=scope,
10201008
thread=thread,
1021-
debug=debug)
1009+
debug=debug,
1010+
fetch_list=fetch_list,
1011+
fetch_info=fetch_info,
1012+
print_period=print_period)
10221013

10231014
trainer._set_infer(is_infer)
10241015
trainer._gen_trainer_desc()

python/paddle/fluid/transpiler/distribute_transpiler.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -793,6 +793,8 @@ def transpile(self,
793793
if self.sync_mode:
794794
fetch_barrier_input.extend(splited_var)
795795

796+
self._update_remote_sparse_update_op(program, need_sparse_update_params)
797+
796798
if self.sync_mode:
797799
# form a WAW dependency
798800
program.global_block().append_op(
@@ -806,11 +808,10 @@ def transpile(self,
806808
})
807809

808810
for param_varname, splited_var in six.iteritems(self.param_var_mapping):
809-
if len(splited_var) <= 1:
810-
continue
811811
orig_param = program.global_block().vars[param_varname]
812812
if param_varname not in self.sparse_param_to_height_sections:
813-
if not self.config.runtime_split_send_recv:
813+
if len(splited_var
814+
) > 1 and not self.config.runtime_split_send_recv:
814815
program.global_block().append_op(
815816
type="concat",
816817
inputs={"X": splited_var},
@@ -820,8 +821,6 @@ def transpile(self,
820821
RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE
821822
})
822823

823-
self._update_remote_sparse_update_op(program,
824-
need_sparse_update_params)
825824
if not self.sync_mode:
826825
lr_ops = self._get_lr_ops()
827826
if len(lr_ops) > 0:

0 commit comments

Comments
 (0)