Skip to content

Commit 0d3d4ae

Browse files
committed
refine prefetch logic
1 parent 17b42fc commit 0d3d4ae

File tree

3 files changed

+64
-52
lines changed

3 files changed

+64
-52
lines changed

paddle/fluid/operators/listen_and_serv_op.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
248248
request_prefetch_handler_.get());
249249

250250
auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock);
251-
auto *prefetch_block = Attr<framework::BlockDesc *>(kPrefetchBlock);
251+
auto grad_to_block_id_str = Attr<std::vector<std::string>>(kPrefetchBlock);
252+
framework::BlockDesc *prefetch_block = nullptr;
252253
auto *program = optimize_block->Program();
253254
framework::Executor executor(dev_place);
254255

@@ -302,8 +303,8 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
302303
AddAttr<bool>("sync_mode", "if works at sync_mode or not").SetDefault(true);
303304
AddAttr<framework::BlockDesc *>(kOptimizeBlock,
304305
"BlockID to run on server side.");
305-
AddAttr<framework::BlockDesc *>(kPrefetchBlock,
306-
"prefetch block to run on server side.");
306+
AddAttr<std::vector<std::string>>(kPrefetchBlock,
307+
"prefetch block to run on server side.");
307308
AddAttr<int>("Fanin", "How many clients send to this server.")
308309
.SetDefault(1);
309310
}

paddle/fluid/operators/listen_and_serv_op.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ namespace paddle {
3030
namespace operators {
3131

3232
constexpr char kOptimizeBlock[] = "OptimizeBlock";
33-
constexpr char kPrefetchBlock[] = "PrefetchBlock";
33+
constexpr char kPrefetchBlock[] = "prefetch_var_name_to_block_id";
3434

3535
void RunServer(std::shared_ptr<detail::RPCServer> service);
3636

python/paddle/fluid/transpiler/distribute_transpiler.py

Lines changed: 59 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -515,21 +515,20 @@ def __op_have_grad_input__(op):
515515
grad_to_block_id, None)
516516

517517
# process distributed lookup_table
518-
prefetch_block = None
518+
prefetch_var_name_to_block_id = []
519519
if self.has_distributed_lookup_table:
520520
pserver_index = self.pserver_endpoints.index(endpoint)
521521
table_opt_block = self._create_table_optimize_block(
522522
pserver_index, pserver_program, pre_block_idx, grad_to_block_id)
523-
prefetch_block = self._create_prefetch_block(
523+
prefetch_var_name_to_block_id = self._create_prefetch_block(
524524
pserver_index, pserver_program, table_opt_block)
525525

526526
# NOTE: if has_distributed_lookup_table is False, then prefetch_block will
527527
# not be executed, so it's safe to use optimize_block to hold the place
528528
if self.has_distributed_lookup_table:
529-
assert prefetch_block is not None
529+
assert len(prefetch_var_name_to_block_id) > 0
530530
else:
531-
assert prefetch_block is None
532-
prefetch_block = pserver_program.global_block()
531+
assert len(prefetch_var_name_to_block_id) == 0
533532

534533
# step5 append the listen_and_serv op
535534
pserver_program.global_block().append_op(
@@ -540,7 +539,7 @@ def __op_have_grad_input__(op):
540539
"OptimizeBlock": pserver_program.block(1),
541540
"endpoint": endpoint,
542541
"Fanin": self.trainer_num,
543-
"PrefetchBlock": prefetch_block,
542+
"prefetch_var_name_to_block_id": prefetch_var_name_to_block_id,
544543
"sync_mode": self.sync_mode,
545544
"grad_to_block_id": grad_to_block_id
546545
})
@@ -608,8 +607,15 @@ def _get_splited_name_and_shape(varname):
608607
def _replace_lookup_table_op_with_prefetch(self, program,
609608
pserver_endpoints):
610609
# 1. replace lookup_table_op with split_ids_op -> prefetch_op -> sum_op
611-
self.prefetch_input_vars = None
612-
self.prefetch_output_vars = None
610+
# self.all_prefetch_input_vars =
611+
# [[var0_prefetch_in_pserver0, var0_prefetch_in_pserver1]
612+
# [var1_prefetch_in_pserver0, var1_prefetch_in_pserver1]]
613+
self.all_prefetch_input_vars = []
614+
615+
# self.all_prefetch_input_vars =
616+
# [[var0_prefetch_in_pserver0, var0_prefetch_in_pserver1]
617+
# [var1_prefetch_in_pserver0, var1_prefetch_in_pserver1]]
618+
self.all_prefetch_output_vars = []
613619

614620
continue_search_lookup_table_op = True
615621
while continue_search_lookup_table_op:
@@ -623,18 +629,19 @@ def _replace_lookup_table_op_with_prefetch(self, program,
623629
ids_name = op.input("Ids")
624630
out_name = op.output("Out")
625631

626-
if self.prefetch_input_vars is None:
627-
ids_var = program.global_block().vars[ids_name[0]]
628-
self.prefetch_input_vars = self.create_splited_vars(
629-
source_var=ids_var,
630-
block=program.global_block(),
631-
tag="_prefetch_in_")
632-
if self.prefetch_output_vars is None:
633-
out_var = program.global_block().vars[out_name[0]]
634-
self.prefetch_output_vars = self.create_splited_vars(
635-
source_var=out_var,
636-
block=program.global_block(),
637-
tag="_prefetch_out_")
632+
ids_var = program.global_block().vars[ids_name[0]]
633+
prefetch_input_vars = self.create_splited_vars(
634+
source_var=ids_var,
635+
block=program.global_block(),
636+
tag="_prefetch_in_")
637+
self.all_prefetch_input_vars.append(prefetch_input_vars)
638+
639+
out_var = program.global_block().vars[out_name[0]]
640+
prefetch_output_vars = self.create_splited_vars(
641+
source_var=out_var,
642+
block=program.global_block(),
643+
tag="_prefetch_out_")
644+
self.all_prefetch_output_vars.append(prefetch_output_vars)
638645

639646
# insert split_ids_op
640647
program.global_block().insert_op(
@@ -646,14 +653,14 @@ def _replace_lookup_table_op_with_prefetch(self, program,
646653
for varname in ids_name
647654
]
648655
},
649-
outputs={"Out": self.prefetch_input_vars})
656+
outputs={"Out": prefetch_input_vars})
650657

651658
# insert prefetch_op
652659
program.global_block().insert_op(
653660
index=op_index + 1,
654661
type="prefetch",
655-
inputs={'X': self.prefetch_input_vars},
656-
outputs={"Out": self.prefetch_output_vars},
662+
inputs={'X': prefetch_input_vars},
663+
outputs={"Out": prefetch_output_vars},
657664
attrs={
658665
"epmap": pserver_endpoints,
659666
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
@@ -663,7 +670,7 @@ def _replace_lookup_table_op_with_prefetch(self, program,
663670
program.global_block().insert_op(
664671
index=op_index + 2,
665672
type="concat",
666-
inputs={'X': self.prefetch_output_vars},
673+
inputs={'X': prefetch_output_vars},
667674
outputs={
668675
"Out": [
669676
program.global_block().vars[varname]
@@ -709,30 +716,34 @@ def _create_prefetch_block(self, pserver_index, pserver_program,
709716
optimize_block):
710717
# STEP: create prefetch block
711718
table_var = pserver_program.global_block().vars[self.table_name]
712-
prefetch_block = pserver_program.create_block(optimize_block.idx)
713-
trainer_ids = self.prefetch_input_vars[pserver_index]
714-
pserver_ids = pserver_program.global_block().create_var(
715-
name=trainer_ids.name,
716-
type=trainer_ids.type,
717-
shape=trainer_ids.shape,
718-
dtype=trainer_ids.dtype)
719-
trainer_out = self.prefetch_output_vars[pserver_index]
720-
pserver_out = pserver_program.global_block().create_var(
721-
name=trainer_out.name,
722-
type=trainer_out.type,
723-
shape=trainer_out.shape,
724-
dtype=trainer_out.dtype)
725-
prefetch_block.append_op(
726-
type="lookup_sparse_table",
727-
inputs={'Ids': pserver_ids,
728-
"W": table_var},
729-
outputs={"Out": pserver_out},
730-
attrs={
731-
"is_sparse": True, # has no effect on lookup_table op
732-
"is_distributed": True,
733-
"padding_idx": -1
734-
})
735-
return prefetch_block
719+
prefetch_var_name_to_block_id = []
720+
for index in range(len(self.all_prefetch_input_vars)):
721+
prefetch_block = pserver_program.create_block(optimize_block.idx)
722+
trainer_ids = self.all_prefetch_input_vars[index][pserver_index]
723+
pserver_ids = pserver_program.global_block().create_var(
724+
name=trainer_ids.name,
725+
type=trainer_ids.type,
726+
shape=trainer_ids.shape,
727+
dtype=trainer_ids.dtype)
728+
trainer_out = self.all_prefetch_output_vars[index][pserver_index]
729+
pserver_out = pserver_program.global_block().create_var(
730+
name=trainer_out.name,
731+
type=trainer_out.type,
732+
shape=trainer_out.shape,
733+
dtype=trainer_out.dtype)
734+
prefetch_block.append_op(
735+
type="lookup_sparse_table",
736+
inputs={'Ids': pserver_ids,
737+
"W": table_var},
738+
outputs={"Out": pserver_out},
739+
attrs={
740+
"is_sparse": True, # has no effect on lookup_table op
741+
"is_distributed": True,
742+
"padding_idx": -1
743+
})
744+
prefetch_var_name_to_block_id.append(trainer_ids.name + ":" + str(
745+
prefetch_block.idx))
746+
return prefetch_var_name_to_block_id
736747

737748
def _create_table_optimize_block(self, pserver_index, pserver_program,
738749
pre_block_idx, grad_to_block_id):

0 commit comments

Comments
 (0)