Skip to content

Commit 2409d0f

Browse files
author
chengduo
authored
Refine regularization for selected_rows (#12369)
* refine regularization for selected_rows * clean lookup_table * refine rpc_server_test * temporarily disable rpc_server_test * fix rpc_server_test * add unit test
1 parent 85c4912 commit 2409d0f

File tree

10 files changed

+240
-175
lines changed

10 files changed

+240
-175
lines changed

paddle/fluid/operators/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,7 @@ op_library(cos_sim_op DEPS cos_sim_functor)
270270
op_library(parallel_do_op DEPS executor)
271271
op_library(unsqueeze_op DEPS reshape_op)
272272
op_library(squeeze_op DEPS reshape_op)
273+
op_library(extract_rows_op DEPS memory)
273274

274275
if (WITH_GPU)
275276
op_library(conv_op DEPS vol2col depthwise_conv im2col)

paddle/fluid/operators/distributed/CMakeLists.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ if(WITH_GRPC)
1717
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
1818
set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
1919
cc_test(grpc_serde_test SRCS grpc_serde_test.cc
20-
DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL)
21-
cc_test(rpc_server_test SRCS rpc_server_test.cc
22-
DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_table_op SERIAL)
20+
DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL)
21+
cc_test(rpc_server_test SRCS rpc_server_test.cc
22+
DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_sparse_table_op SERIAL)
2323
return()
2424
endif()
2525

paddle/fluid/operators/distributed/rpc_server_test.cc

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ namespace framework = paddle::framework;
3030
namespace platform = paddle::platform;
3131
namespace distributed = paddle::operators::distributed;
3232

33-
USE_OP(lookup_table);
33+
USE_NO_KERNEL_OP(lookup_sparse_table);
3434

3535
std::unique_ptr<distributed::RPCServer> g_rpc_service;
3636
std::unique_ptr<distributed::RequestHandler> g_req_handler;
@@ -42,13 +42,13 @@ framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc* program) {
4242
framework::VariableNameMap input({{"W", {"w"}}, {"Ids", {"ids"}}});
4343
framework::VariableNameMap output({{"Output", {"out"}}});
4444
auto op = block->AppendOp();
45-
op->SetType("lookup_table");
45+
op->SetType("lookup_sparse_table");
4646
op->SetInput("W", {"w"});
4747
op->SetInput("Ids", {"ids"});
4848
op->SetOutput("Out", {"out"});
4949

5050
auto& out = *root_block->Var("out");
51-
out.SetType(framework::proto::VarType::SELECTED_ROWS);
51+
out.SetType(framework::proto::VarType::LOD_TENSOR);
5252
out.SetShape({10, 10});
5353

5454
return block;
@@ -59,20 +59,19 @@ void CreateVarsOnScope(framework::Scope* scope, platform::CPUPlace* place) {
5959
w_var->GetMutable<framework::SelectedRows>();
6060

6161
auto out_var = scope->Var("out");
62-
out_var->GetMutable<framework::SelectedRows>();
62+
out_var->GetMutable<framework::LoDTensor>();
6363

6464
auto ids_var = scope->Var("ids");
65-
ids_var->GetMutable<framework::SelectedRows>();
65+
ids_var->GetMutable<framework::LoDTensor>();
6666
}
6767

6868
void InitTensorsOnClient(framework::Scope* scope, platform::CPUPlace* place,
6969
int64_t rows_numel) {
7070
CreateVarsOnScope(scope, place);
71-
auto ids_var = scope->Var("ids")->GetMutable<framework::SelectedRows>();
72-
auto rows = ids_var->mutable_rows();
73-
for (int64_t i = 0; i < rows_numel; ++i) rows->push_back(i * 2);
74-
ids_var->mutable_value()->Resize({rows_numel, 1});
75-
ids_var->mutable_value()->mutable_data<float>(*place);
71+
auto ids_var = scope->Var("ids")->GetMutable<framework::LoDTensor>();
72+
int64_t* ids_ptr =
73+
ids_var->mutable_data<int64_t>(framework::DDim({rows_numel, 1}), *place);
74+
for (int64_t i = 0; i < rows_numel; ++i) ids_ptr[i] = i * 2;
7675
}
7776

7877
void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place,
@@ -148,11 +147,11 @@ TEST(PREFETCH, CPU) {
148147
client->AsyncPrefetchVar(ep, ctx, scope, in_var_name, out_var_name);
149148
client->Wait();
150149
auto var = scope.Var(out_var_name);
151-
auto value = var->GetMutable<framework::SelectedRows>()->value();
152-
auto ptr = value.mutable_data<float>(place);
150+
auto value = var->GetMutable<framework::LoDTensor>();
151+
auto ptr = value->mutable_data<float>(place);
153152

154153
for (int64_t i = 0; i < rows_numel; ++i) {
155-
EXPECT_EQ(ptr[0 + i * value.dims()[1]], static_cast<float>(i * 2));
154+
EXPECT_EQ(ptr[0 + i * value->dims()[1]], static_cast<float>(i * 2));
156155
}
157156
}
158157

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include <string>
16+
#include <vector>
17+
#include "paddle/fluid/framework/op_registry.h"
18+
19+
namespace paddle {
20+
namespace operators {
21+
22+
// Shape inference for extract_rows: verifies that input X exists and is a
// SelectedRows variable, that output Out exists, and fixes Out's shape to a
// column vector [rows(X), 1].
class ExtractRowsOpInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of ExtractRowsOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output(Out) of ExtractRowsOp should not be null.");
    // extract_rows only makes sense on a sparse (SelectedRows) input.
    PADDLE_ENFORCE_EQ(ctx->GetInputsVarType("X")[0],
                      framework::proto::VarType::SELECTED_ROWS,
                      "The type of input(X) must be SelectedRows.");
    const auto x_dims = ctx->GetInputDim("X");
    // One extracted row index per row of X, laid out as an [N, 1] tensor.
    ctx->SetOutputDim(
        "Out", framework::make_ddim(std::vector<int64_t>{x_dims[0], 1}));
  }
};
38+
39+
class ExtractRowsOp : public framework::OperatorBase {
40+
public:
41+
ExtractRowsOp(const std::string &type,
42+
const framework::VariableNameMap &inputs,
43+
const framework::VariableNameMap &outputs,
44+
const framework::AttributeMap &attrs)
45+
: framework::OperatorBase(type, inputs, outputs, attrs) {}
46+
47+
private:
48+
void RunImpl(const framework::Scope &scope,
49+
const platform::Place &place) const override {
50+
auto &in = scope.FindVar(Input("X"))->Get<framework::SelectedRows>();
51+
auto out = scope.FindVar(Output("Out"))->GetMutable<framework::LoDTensor>();
52+
53+
auto in_rows = in.rows();
54+
auto out_dim = framework::make_ddim(
55+
std::vector<int64_t>{static_cast<int64_t>(in_rows.size()), 1});
56+
auto dst_ptr = out->mutable_data<int64_t>(out_dim, in.place());
57+
58+
if (paddle::platform::is_gpu_place(in.place())) {
59+
#ifdef PADDLE_WITH_CUDA
60+
platform::DeviceContextPool &pool =
61+
platform::DeviceContextPool::Instance();
62+
auto *dev_ctx = pool.Get(in.place());
63+
auto src_ptr = in_rows.Data(in.place());
64+
auto stream =
65+
reinterpret_cast<const platform::CUDADeviceContext &>(*dev_ctx)
66+
.stream();
67+
memory::Copy(boost::get<platform::CUDAPlace>(out->place()), dst_ptr,
68+
boost::get<platform::CUDAPlace>(in.place()), src_ptr,
69+
in_rows.size() * sizeof(int64_t), stream);
70+
#else
71+
PADDLE_THROW("Not compiled with CUDA.");
72+
#endif
73+
} else {
74+
memory::Copy(platform::CPUPlace(), dst_ptr, platform::CPUPlace(),
75+
in_rows.data(), in_rows.size() * sizeof(int64_t));
76+
}
77+
}
78+
};
79+
80+
// Proto maker for extract_rows: declares the SelectedRows input X, the
// Tensor output Out, and the operator's user-facing documentation.
class ExtractRowsOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "(SelectedRows). The input tensor of extract_rows operator,"
             " and its type is SelectedRows.");
    // Fixed doc-string typo: "The the rows" -> "The rows".
    AddOutput("Out", "(Tensor). The rows of input(X).");

    AddComment(R"DOC(
ExtractRows Operator.

The function of extract_rows_op is extracting the rows from the input(X)
whose type is SelectedRows.

)DOC");
  }
};
97+
98+
} // namespace operators
99+
} // namespace paddle
100+
101+
namespace ops = paddle::operators;
// NOTE(review): registered without a gradient op — presumably extract_rows
// (which only reads row indices) needs no backward pass; confirm with the
// optimizer/regularizer call sites that consume it.
REGISTER_OPERATOR(extract_rows, ops::ExtractRowsOp, ops::ExtractRowsOpMaker,
                  ops::ExtractRowsOpInferShape);

paddle/fluid/operators/lookup_table_op.cc

Lines changed: 16 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -33,19 +33,15 @@ class LookupTableOp : public framework::OperatorWithKernel {
3333
auto table_dims = ctx->GetInputDim("W");
3434
auto ids_dims = ctx->GetInputDim("Ids");
3535

36-
auto ids_var_type = ctx->GetInputsVarType("Ids").front();
37-
// The type of Ids(Input) is SelectedRows or LoDTensor, when Ids's type
38-
// is LoDTensor, this tensor contains the ids to be looked up in W
39-
// and it must be a column vector with rank = 2 while the 2nd dimension
40-
// size must be 1, when Ids's type is SelectedRows, the rows of Ids
41-
// contains the ids to be looked up in W;
42-
if (ids_var_type == framework::proto::VarType::LOD_TENSOR) {
43-
PADDLE_ENFORCE_EQ(ids_dims.size(), 2);
44-
PADDLE_ENFORCE_EQ(ids_dims[1], 1);
45-
}
36+
PADDLE_ENFORCE_EQ(ids_dims.size(), 2);
37+
PADDLE_ENFORCE_EQ(ids_dims[1], 1);
4638

4739
ctx->SetOutputDim("Out", {ids_dims[0], table_dims[1]});
48-
ctx->ShareLoD("Ids", /*->*/ "Out");
40+
41+
if (ctx->GetOutputsVarType("Out")[0] ==
42+
framework::proto::VarType::LOD_TENSOR) {
43+
ctx->ShareLoD("Ids", /*->*/ "Out");
44+
}
4945
}
5046

5147
protected:
@@ -62,17 +58,12 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
6258
AddInput("W",
6359
"(Tensor) The input represents embedding tensors, "
6460
"which is a learnable parameter.");
65-
AddInput(
66-
"Ids",
67-
"(Tensor or SelectedRows) Ids's type can be Tensor or "
68-
"SelectedRows, when Ids's type is Tensor, this tensor contains "
69-
"the ids to be looked up in W and it must be a column vector with "
70-
"rank = 2 while the 2nd dimension size must be 1; when Ids's type is "
71-
"SelectedRows, the rows of Ids contains the ids to be looked up "
72-
"in W.");
73-
AddOutput("Out",
74-
"(Tensor or SelectedRows) The lookup results, which have the "
75-
"same type as W.");
61+
AddInput("Ids",
62+
"An input with type int32 or int64 "
63+
"contains the ids to be looked up in W. "
64+
"Ids must be a column vector with rank = 2. "
65+
"The 2nd dimension size must be 1.");
66+
AddOutput("Out", "The lookup results, which have the same type as W.");
7667
AddAttr<bool>("is_sparse",
7768
"(boolean, default false) "
7869
"Sparse update.")
@@ -90,15 +81,10 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
9081
Lookup Table Operator.
9182
9283
This operator is used to perform lookups on the parameter W,
93-
then concatenated into a dense or sparse tensor.
94-
95-
The type of Ids(Input) is SelectedRows, Tensor or LoDTensor, when Ids's
96-
type is SelectedRows, the rows of Ids contains the ids to be looked up in W;
97-
when Ids's type is Tensor, this tensor contains the ids to be looked up in W
98-
and it must be a column vector with rank = 2 while the 2nd dimension size must be 1,
99-
at this time, Ids can carry the LoD (Level of Details) information, or not, and
100-
the output only shares the LoD information with input Ids.
84+
then concatenated into a dense tensor.
10185
86+
The input Ids can carry the LoD (Level of Details) information,
87+
or not. And the output only shares the LoD information with input Ids.
10288
10389
)DOC");
10490
}

0 commit comments

Comments
 (0)