
Commit ab953ba

Merge pull request #10973 from jacquesqiao/fix-prefetch
Fix and optimize async distribute lookup table
2 parents 38af7bc + 0858a50 commit ab953ba

8 files changed: 83 additions & 59 deletions

paddle/fluid/framework/selected_rows.cc

Lines changed: 20 additions & 15 deletions
@@ -121,24 +121,29 @@ bool SelectedRows::HasKey(int64_t key) const {
 }
 
 std::vector<std::pair<int64_t, int64_t>> SelectedRows::Get(
-    std::vector<int64_t> keys, framework::Tensor* value) const {
+    const std::vector<int64_t>& keys, framework::Tensor* value) const {
   PADDLE_ENFORCE(value->IsInitialized(),
                  "The value tensor should be initialized.");
   std::vector<std::pair<int64_t, int64_t>> non_keys_pair;
-  int64_t value_width = value_->numel() / value_->dims()[0];
-  PADDLE_ENFORCE_EQ(value_width, value->numel() / value->dims()[0],
-                    "output tensor should have the same shape with table "
-                    "execpt the dims[0].");
-
-  for (size_t i = 0; i < keys.size(); ++i) {
-    int64_t index = Index(keys[i]);
-    if (index == -1) {
-      non_keys_pair.push_back(std::make_pair(keys[i], static_cast<int64_t>(i)));
-    } else {
-      framework::VisitDataType(
-          framework::ToDataType(value_->type()),
-          TensorCopyVisitor(value, i * value_width, *value_.get(),
-                            index * value_width, value_width));
+  if (keys.empty()) {
+    VLOG(3) << "keys is empty, please check data!";
+  } else {
+    int64_t value_width = value_->numel() / value_->dims()[0];
+    PADDLE_ENFORCE_EQ(value_width, value->numel() / value->dims()[0],
+                      "output tensor should have the same shape with table "
+                      "except the dims[0].");
+
+    for (size_t i = 0; i < keys.size(); ++i) {
+      int64_t index = Index(keys[i]);
+      if (index == -1) {
+        non_keys_pair.push_back(
+            std::make_pair(keys[i], static_cast<int64_t>(i)));
+      } else {
+        framework::VisitDataType(
+            framework::ToDataType(value_->type()),
+            TensorCopyVisitor(value, i * value_width, *value_.get(),
+                              index * value_width, value_width));
+      }
     }
   }
   return non_keys_pair;
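
The reworked Get above guards against an empty key list and collects the keys it cannot find instead of assuming every key exists. As a rough, illustrative mirror of that control flow in Python (selected_rows_get, table_rows and table_value are made-up names for this sketch, not the actual C++ API):

import numpy as np

def selected_rows_get(table_rows, table_value, keys, out):
    """Copy the table row for each key into out; collect (key, position) for misses."""
    non_keys_pair = []
    if len(keys) == 0:
        print("keys is empty, please check data!")
        return non_keys_pair
    # output must have the same row width as the table
    assert table_value.shape[1] == out.shape[1]
    index = {row_id: i for i, row_id in enumerate(table_rows)}
    for i, key in enumerate(keys):
        pos = index.get(key, -1)
        if pos == -1:
            non_keys_pair.append((key, i))
        else:
            out[i] = table_value[pos]
    return non_keys_pair

table_rows = [3, 7, 9]                                     # row ids held by this table
table_value = np.arange(6, dtype="float32").reshape(3, 2)  # one row per id
out = np.zeros((3, 2), dtype="float32")
print(selected_rows_get(table_rows, table_value, [7, 5, 3], out))  # [(5, 1)]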

paddle/fluid/framework/selected_rows.h

Lines changed: 1 addition & 1 deletion
@@ -82,7 +82,7 @@ class SelectedRows {
    * @return a list of pair which contains the non-exists key and the index in
    * the value
    */
-  std::vector<std::pair<int64_t, int64_t>> Get(std::vector<int64_t> keys,
+  std::vector<std::pair<int64_t, int64_t>> Get(const std::vector<int64_t>& keys,
                                                framework::Tensor* value) const;
 
   /*

paddle/fluid/operators/detail/grpc_server.cc

Lines changed: 4 additions & 7 deletions
@@ -177,11 +177,8 @@ class RequestPrefetch final : public RequestBase {
         program_(program),
         prefetch_ctx_(prefetch_ctx),
         req_id_(req_id) {
-    if (sync_mode_) {
-      request_.reset(new VariableResponse(scope, dev_ctx_, false));
-    } else {
-      request_.reset(new VariableResponse(scope, dev_ctx_, true));
-    }
+    // prefetch always create a new sub scope
+    request_.reset(new VariableResponse(scope, dev_ctx_, true));
     int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable);
     service_->RequestAsyncUnary(
         method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
@@ -198,10 +195,10 @@ class RequestPrefetch final : public RequestBase {
     std::string var_name = request_->OutVarname();
     VLOG(3) << "RequestPrefetch " << var_name;
     auto var_desc = program_->Block(0).FindVar(var_name);
-    framework::Scope* local_scope = &scope_->NewScope();
+    framework::Scope* local_scope = request_->GetMutableLocalScope();
     auto* var = local_scope->FindVar(var_name);
     InitializeVariable(var, var_desc->GetType());
-    executor_->RunPreparedContext(prefetch_ctx_, scope_);
+    executor_->RunPreparedContext(prefetch_ctx_, local_scope);
 
     SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply_);

paddle/fluid/operators/listen_and_serv_op.cc

Lines changed: 1 addition & 0 deletions
@@ -207,6 +207,7 @@ static void AsyncUpdateThread(
   while (!exit_flag) {
     const detail::ReceivedMessage v = queue->Pop();
     auto recv_var_name = v.first;
+    VLOG(4) << "async update " << recv_var_name;
     auto var = v.second->GetVar();
     if (var == nullptr) {
       LOG(ERROR) << "Can not find server side var: " << recv_var_name;

paddle/fluid/operators/lookup_sparse_table_op.cc

Lines changed: 1 addition & 1 deletion
@@ -127,7 +127,7 @@ class LookupSparseTableOpMaker : public framework::OpProtoAndCheckerMaker {
         .SetDefault(-1.0f);
     AddAttr<float>("max",
                    "(float, default 1.0) "
-                   "Maximun value of uniform random")
+                   "Maximum value of uniform random")
         .SetDefault(1.0f);
     AddAttr<int>("seed",
                  "(int, default 0) "

paddle/fluid/operators/sgd_op.h

Lines changed: 6 additions & 2 deletions
@@ -96,8 +96,12 @@ class SGDOpKernel : public framework::OpKernel<T> {
         return;
       }
 
-      size_t param_row_width = param.value().numel() / param.rows().size();
-      size_t grad_row_width = grad.value().numel() / grad.rows().size();
+      auto param_row_width = param.value().dims()[1];
+      auto grad_row_width = grad.value().dims()[1];
+      VLOG(4) << " param rows: " << param.rows().size()
+              << " param memory rows: " << param.value().dims()[0]
+              << " grad rows: " << grad.rows().size()
+              << " grad memory rows: " << grad.value().dims()[0];
       PADDLE_ENFORCE_EQ(param_row_width, grad_row_width,
                         "param_row should have the same size with grad_row");
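
The old width computation divided numel() by rows().size(); the new VLOG output separates "rows" from "memory rows", which suggests that for the distributed lookup table the value tensor can hold a different number of memory rows than there are selected row ids, in which case only dims()[1] gives the true row width. A small numpy illustration of that mismatch (shapes are made up):

import numpy as np

rows = [0, 4, 7]                               # 3 selected row ids
value = np.zeros((10, 8), dtype="float32")     # 10 memory rows, row width 8

wrong_width = value.size // len(rows)          # 80 // 3 == 26, not the row width
right_width = value.shape[1]                   # 8, what dims()[1] returns
print(wrong_width, right_width)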

python/paddle/fluid/framework.py

Lines changed: 2 additions & 1 deletion
@@ -797,7 +797,7 @@ def rename_var(self, name, new_name):
         Rename variable in vars and ops' inputs and outputs
         """
         if not self.has_var(name):
-            raise ValueError("var %s is not in current" % name)
+            raise ValueError("var %s is not in current block" % name)
         v = self.var(name)
         if type(v) == Parameter:
             var_type = "Parameter"
@@ -843,6 +843,7 @@ def rename_var(self, name, new_name):
         self.vars[new_name] = var
         del self.vars[name]
         self.sync_with_cpp()
+        return var
 
     def remove_var(self, name):
         self.sync_with_cpp()
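
rename_var now returns the renamed variable; previously it returned nothing, so callers had to look the variable up again. The transpiler change below relies on this when it renames the pserver-side table gradient. A minimal sketch of using the return value, assuming the Fluid Python API of this era (names are illustrative):

import paddle.fluid as fluid

prog = fluid.Program()
block = prog.global_block()
block.create_var(name="table@GRAD", shape=[1000, 8], dtype="float32")

# With this change, the renamed variable comes straight back to the caller.
grad_var = block.rename_var("table@GRAD", "table@GRAD.pserver_0")
print(grad_var.name)  # table@GRAD.pserver_0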

python/paddle/fluid/transpiler/distribute_transpiler.py

Lines changed: 48 additions & 32 deletions
@@ -273,15 +273,25 @@ def transpile(self,
                 if param_grad[0].name == self.table_name
             ][0]
             table_grad_var = self.table_param_grad[1]
-            self.table_grad_list = [
-                program.global_block().create_var(
-                    name="%s.trainer_%d.pserver_%d" %
-                    (table_grad_var.name, trainer_id, index),
-                    type=table_grad_var.type,
-                    shape=table_grad_var.shape,
-                    dtype=table_grad_var.dtype)
-                for index in range(len(self.pserver_endpoints))
-            ]
+            if self.sync_mode:
+                self.trainer_side_table_grad_list = [
+                    program.global_block().create_var(
+                        name="%s.trainer_%d.pserver_%d" %
+                        (table_grad_var.name, trainer_id, index),
+                        type=table_grad_var.type,
+                        shape=table_grad_var.shape,
+                        dtype=table_grad_var.dtype)
+                    for index in range(len(self.pserver_endpoints))
+                ]
+            else:
+                self.trainer_side_table_grad_list = [
+                    program.global_block().create_var(
+                        name="%s.pserver_%d" % (table_grad_var.name, index),
+                        type=table_grad_var.type,
+                        shape=table_grad_var.shape,
+                        dtype=table_grad_var.dtype)
+                    for index in range(len(self.pserver_endpoints))
+                ]
 
         grad_blocks = split_dense_variable(grad_list, len(pserver_endpoints))
         param_blocks = split_dense_variable(param_list, len(pserver_endpoints))
@@ -400,7 +410,8 @@ def transpile(self,
                 attrs={"axis": 0})
 
         if self.has_distributed_lookup_table:
-            self._replace_lookup_table_op_with_prefetch(program, eplist)
+            self._replace_lookup_table_op_with_prefetch(program,
+                                                        pserver_endpoints)
            self._split_table_grad_and_add_send_vars(program, pserver_endpoints)
 
     def get_trainer_program(self):
@@ -537,7 +548,7 @@ def __append_optimize_op__(op, block, grad_to_block_id):
         if self.has_distributed_lookup_table:
             pserver_index = self.pserver_endpoints.index(endpoint)
             table_opt_block = self._create_table_optimize_block(
-                pserver_index, pserver_program, pre_block_idx)
+                pserver_index, pserver_program, pre_block_idx, grad_to_block_id)
             prefetch_block = self._create_prefetch_block(
                 pserver_index, pserver_program, table_opt_block)
 
@@ -621,7 +632,8 @@ def _get_splited_name_and_shape(varname):
         return s_prog
 
     # transpiler function for dis lookup_table
-    def _replace_lookup_table_op_with_prefetch(self, program, eplist):
+    def _replace_lookup_table_op_with_prefetch(self, program,
+                                               pserver_endpoints):
         # 1. replace lookup_table_op with split_ids_op -> prefetch_op -> sum_op
         self.prefetch_input_vars = None
         self.prefetch_output_vars = None
@@ -670,7 +682,7 @@ def _replace_lookup_table_op_with_prefetch(self, program, eplist):
             inputs={'X': self.prefetch_input_vars},
             outputs={"Out": self.prefetch_output_vars},
             attrs={
-                "epmap": eplist,
+                "epmap": pserver_endpoints,
                 RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
             })
 
@@ -707,11 +719,11 @@ def _split_table_grad_and_add_send_vars(self, program, pserver_endpoints):
             inputs={
                 'Ids': [program.global_block().vars[table_grad_name]]
             },
-            outputs={"Out": self.table_grad_list})
+            outputs={"Out": self.trainer_side_table_grad_list})
         program.global_block().insert_op(
             index=op_index + 2,
             type="send_vars",
-            inputs={'X': self.table_grad_list},
+            inputs={'X': self.trainer_side_table_grad_list},
             outputs={},
             attrs={
                 "sync_send": True,
@@ -750,16 +762,7 @@ def _create_prefetch_block(self, pserver_index, pserver_program,
         return prefetch_block
 
     def _create_table_optimize_block(self, pserver_index, pserver_program,
-                                     pre_block_idx):
-        def _clone_var(block, var, persistable=True):
-            assert isinstance(var, Variable)
-            return block.create_var(
-                name=var.name,
-                shape=var.shape,
-                dtype=var.dtype,
-                type=var.type,
-                persistable=persistable)
-
+                                     pre_block_idx, grad_to_block_id):
         # STEP: create table optimize block
         # create table param and grad var in pserver program
         origin_param_var = self.origin_program.global_block().vars[
@@ -770,11 +773,11 @@ def _clone_var(block, var, persistable=True):
             dtype=origin_param_var.dtype,
             type=core.VarDesc.VarType.SELECTED_ROWS,
             persistable=True)
-        grad_var = _clone_var(
-            pserver_program.global_block(),
+        # parameter must be selected rows
+        param_var.desc.set_type(core.VarDesc.VarType.SELECTED_ROWS)
+        grad_var = pserver_program.global_block().clone_variable(
             self.origin_program.global_block().vars[grad_var_name(
-                self.table_name)],
-            persistable=False)
+                self.table_name)])
 
         # create table optimize block in pserver program
         table_opt_op = [
@@ -788,7 +791,7 @@ def _clone_var(block, var, persistable=True):
         if self.sync_mode:
             # create grad vars in pserver program
             table_grad_var = self.table_param_grad[1]
-            table_grad_list = [
+            pserver_side_table_grad_list = [
                 pserver_program.global_block().create_var(
                     name="%s.trainer_%d.pserver_%d" %
                     (table_grad_var.name, index, pserver_index),
@@ -798,11 +801,21 @@ def _clone_var(block, var, persistable=True):
                 for index in range(self.trainer_num)
             ]
 
-            # append sum op for table_grad_list
+            # append sum op for pserver_side_table_grad_list
             table_opt_block.append_op(
                 type="sum",
-                inputs={"X": table_grad_list},
+                inputs={"X": pserver_side_table_grad_list},
                 outputs={"Out": [grad_var]})
+        else:
+            # in async_mode, for table gradient, it also need to be splited to each parameter server
+            origin_grad_name = grad_var.name
+            splited_grad_name = self.trainer_side_table_grad_list[
+                pserver_index].name
+            if not splited_grad_name.startswith(origin_grad_name):
+                raise ValueError("origin_grad_var: " + splited_grad_name +
+                                 " grad_var:" + grad_var.name)
+            grad_var = pserver_program.global_block().rename_var(
+                origin_grad_name, splited_grad_name)
 
         lr_var = pserver_program.global_block().vars[table_opt_op.input(
             "LearningRate")[0]]
@@ -818,6 +831,9 @@ def _clone_var(block, var, persistable=True):
             outputs=outputs,
             attrs=table_opt_op.attrs)
 
+        # add table parameter gradient and it's block id to grad_to_block_id
+        grad_to_block_id.append(grad_var.name + ":" + str(table_opt_block.idx))
+
         return table_opt_block
 
     # ====================== private transpiler functions =====================
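
Taken together, these transpiler changes make the async path (sync_mode=False) split the lookup-table gradient per parameter server, rename it on the pserver side, and register its optimize block in grad_to_block_id. A minimal sketch of how this path is driven from user code, assuming the Fluid API of this era (the endpoints, sizes, and toy network are illustrative, not from the commit):

import paddle.fluid as fluid

# A toy network with a distributed lookup table (is_distributed=True).
ids = fluid.layers.data(name="ids", shape=[1], dtype="int64", lod_level=1)
label = fluid.layers.data(name="label", shape=[1], dtype="float32")
emb = fluid.layers.embedding(
    input=ids, size=[100000, 64], is_sparse=True, is_distributed=True)
pool = fluid.layers.sequence_pool(input=emb, pool_type="sum")
pred = fluid.layers.fc(input=pool, size=1)
cost = fluid.layers.square_error_cost(input=pred, label=label)
avg_cost = fluid.layers.mean(x=cost)
fluid.optimizer.SGD(learning_rate=0.01).minimize(avg_cost)

# Async transpile: the table gradient is sent to, and renamed on, each pserver.
t = fluid.DistributeTranspiler()
t.transpile(
    trainer_id=0,
    pservers="127.0.0.1:6174,127.0.0.1:6175",
    trainers=2,
    sync_mode=False)
trainer_prog = t.get_trainer_program()
pserver_prog = t.get_pserver_program("127.0.0.1:6174")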
