Skip to content

Commit ff99d94

Browse files
author
Yancey
authored
Merge pull request #10164 from Yancey1989/lookup_sparse_table_op
add lookup_sparse_table_op
2 parents 1945b72 + 1a93253 commit ff99d94

File tree

10 files changed

+319
-22
lines changed

10 files changed

+319
-22
lines changed

paddle/fluid/framework/lod_tensor_test.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -255,11 +255,11 @@ TEST(LoDTensor, RecordIO) {
255255
std::unique_ptr<std::istream> stream_ptr(stream);
256256
recordio::Scanner scanner(std::move(stream_ptr));
257257
auto tensors = ReadFromRecordIO(&scanner, ctx);
258-
ASSERT_EQ(tensors.size(), 2);
258+
ASSERT_EQ(tensors.size(), static_cast<size_t>(2));
259259
assert_tensor_ok(tensors[0]);
260260
assert_tensor_ok(tensors[1]);
261261
tensors = ReadFromRecordIO(&scanner, ctx);
262-
ASSERT_EQ(tensors.size(), 2);
262+
ASSERT_EQ(tensors.size(), static_cast<size_t>(2));
263263
assert_tensor_ok(tensors[0]);
264264
assert_tensor_ok(tensors[1]);
265265
}

paddle/fluid/framework/selected_rows.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -120,11 +120,11 @@ bool SelectedRows::HasKey(int64_t key) const {
120120
: true;
121121
}
122122

123-
std::vector<int64_t> SelectedRows::Get(std::vector<int64_t> keys,
124-
framework::Tensor* value) const {
123+
std::vector<std::pair<int64_t, int64_t>> SelectedRows::Get(
124+
std::vector<int64_t> keys, framework::Tensor* value) const {
125125
PADDLE_ENFORCE(value->IsInitialized(),
126126
"The value tensor should be initialized.");
127-
std::vector<int64_t> non_keys;
127+
std::vector<std::pair<int64_t, int64_t>> non_keys_pair;
128128
int64_t value_width = value_->numel() / value_->dims()[0];
129129
PADDLE_ENFORCE_EQ(value_width, value->numel() / value->dims()[0],
130130
"output tensor should have the same shape with table "
@@ -133,15 +133,15 @@ std::vector<int64_t> SelectedRows::Get(std::vector<int64_t> keys,
133133
for (size_t i = 0; i < keys.size(); ++i) {
134134
int64_t index = Index(keys[i]);
135135
if (index == -1) {
136-
non_keys.push_back(keys[i]);
136+
non_keys_pair.push_back(std::make_pair(keys[i], static_cast<int64_t>(i)));
137137
} else {
138138
framework::VisitDataType(
139139
framework::ToDataType(value_->type()),
140140
TensorCopyVisitor(value, i * value_width, *value_.get(),
141141
index * value_width, value_width));
142142
}
143143
}
144-
return non_keys;
144+
return non_keys_pair;
145145
}
146146

147147
bool SelectedRows::Set(int64_t key, const framework::Tensor& value) {

paddle/fluid/framework/selected_rows.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License. */
1515
#pragma once
1616

1717
#include <algorithm>
18+
#include <utility>
1819
#include <vector>
1920

2021
#include "paddle/fluid/framework/lod_tensor.h"
@@ -78,10 +79,11 @@ class SelectedRows {
7879
/*
7980
* @brief Get value by the key list, if the
8081
*
81-
* @return a list of keys which does not exists in table
82+
* @return a list of pair which contains the non-exists key and the index in
83+
* the value
8284
*/
83-
std::vector<int64_t> Get(std::vector<int64_t> keys,
84-
framework::Tensor* tensor) const;
85+
std::vector<std::pair<int64_t, int64_t>> Get(std::vector<int64_t> keys,
86+
framework::Tensor* value) const;
8587

8688
/*
8789
* @brief Set a key-value pair into the table.

paddle/fluid/framework/selected_rows_test.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ TEST_F(SelectedRowsTester, SerializeAndDeseralize) {
5959
ASSERT_EQ(selected_rows_->GetCompleteDims(), dst_tensor.GetCompleteDims());
6060
}
6161

62-
TEST_F(SelectedRowsTester, Table) {
62+
TEST_F(SelectedRowsTester, SparseTable) {
6363
platform::CPUPlace cpu;
6464
SelectedRows table;
6565
// initialize a sparse table
@@ -87,11 +87,11 @@ TEST_F(SelectedRowsTester, Table) {
8787
framework::Tensor get_value;
8888
get_value.mutable_data<float>(framework::make_ddim({2, 100}), cpu);
8989
std::vector<int64_t> keys({non_key, key});
90-
auto non_keys = table.Get(keys, &get_value);
90+
auto non_key_pairs = table.Get(keys, &get_value);
9191

9292
ASSERT_EQ(get_value.data<float>()[100], static_cast<float>(10));
93-
ASSERT_EQ(non_keys.size(), static_cast<size_t>(1));
94-
ASSERT_EQ(non_keys[0], non_key);
93+
ASSERT_EQ(non_key_pairs.size(), static_cast<size_t>(1));
94+
ASSERT_EQ(non_key_pairs[0].first, non_key);
9595
}
9696

9797
} // namespace framework

paddle/fluid/operators/detail/serde_test.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ void RunSerdeTestSelectedRows(platform::Place place) {
108108
EXPECT_FLOAT_EQ(tensor_data2[i], 32.7);
109109
}
110110
for (size_t i = 0; i < rows2->size(); ++i) {
111-
EXPECT_EQ(rows_data2[i], i);
111+
EXPECT_EQ(rows_data2[i], static_cast<int64_t>(i));
112112
}
113113
EXPECT_EQ(slr2->height(), 1000);
114114
}
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include <algorithm>
16+
17+
#include "paddle/fluid/framework/data_type.h"
18+
#include "paddle/fluid/framework/op_registry.h"
19+
#include "paddle/fluid/operators/math/math_function.h"
20+
#include "paddle/fluid/platform/device_context.h"
21+
22+
namespace paddle {
23+
namespace operators {
24+
25+
constexpr int64_t kNoPadding = -1;
26+
27+
class LookupSparseTableInferShape : public framework::InferShapeBase {
28+
public:
29+
void operator()(framework::InferShapeContext *ctx) const override {
30+
PADDLE_ENFORCE(ctx->HasOutput("Out"),
31+
"Output(Out) of LookupSparseTableOp should not be null.");
32+
auto shape_w = ctx->GetInputDim("W");
33+
auto shape_ids = ctx->GetInputDim("Ids");
34+
shape_w[0] = shape_ids.size();
35+
ctx->SetOutputDim("Out", shape_w);
36+
}
37+
};
38+
39+
class LookupSparseTableOp : public framework::OperatorBase {
40+
public:
41+
using framework::OperatorBase::OperatorBase;
42+
43+
private:
44+
void RunImpl(const framework::Scope &scope,
45+
const platform::Place &dev_place) const override {
46+
auto out_var = scope.FindVar(Output("Out"));
47+
auto w_var = scope.FindVar(Input("W"));
48+
auto ids_var = scope.FindVar(Input("Ids"));
49+
unsigned int seed = static_cast<unsigned int>(Attr<int>("seed"));
50+
float min = Attr<float>("min");
51+
float max = Attr<float>("max");
52+
bool auto_grown_table = Attr<bool>("auto_grown_table");
53+
54+
PADDLE_ENFORCE(out_var->IsType<framework::LoDTensor>(),
55+
"The type of Out var should be LodTensor.");
56+
PADDLE_ENFORCE(w_var->IsType<framework::SelectedRows>(),
57+
"The type of W var should be SelectedRows.");
58+
PADDLE_ENFORCE(ids_var->IsType<framework::LoDTensor>(),
59+
"The type of Ids var should be LoDTensor.");
60+
auto &ids_t = ids_var->Get<framework::LoDTensor>();
61+
auto out_t = out_var->GetMutable<framework::LoDTensor>();
62+
auto w_t = w_var->GetMutable<framework::SelectedRows>();
63+
std::vector<int64_t> keys;
64+
keys.resize(ids_t.numel());
65+
for (size_t i = 0; i < ids_t.numel(); ++i) {
66+
keys[i] = ids_t.data<int64_t>()[i];
67+
}
68+
69+
// TODO(Yancey1989): support CUDA Place for the sparse table
70+
platform::CPUPlace cpu;
71+
auto out_shape = w_t->value().dims();
72+
out_shape[0] = keys.size();
73+
out_t->Resize(out_shape);
74+
out_t->mutable_data(cpu, w_t->value().type());
75+
PADDLE_ENFORCE_EQ(framework::ToDataType(w_t->value().type()),
76+
framework::proto::VarType::FP32,
77+
"The sparse table only support FP32");
78+
auto non_keys_pair = w_t->Get(keys, out_t);
79+
if (!auto_grown_table) {
80+
PADDLE_ENFORCE_EQ(non_keys_pair.size(), static_cast<size_t>(0),
81+
"some keys do not exist in the sparse table.");
82+
}
83+
auto value_shape = w_t->value().dims();
84+
value_shape[0] = 1;
85+
for (const auto &it : non_keys_pair) {
86+
const auto key = it.first;
87+
const auto index = it.second;
88+
framework::Tensor value;
89+
value.Resize(value_shape);
90+
auto data = value.mutable_data<float>(cpu);
91+
92+
std::minstd_rand engine;
93+
engine.seed(seed);
94+
std::uniform_real_distribution<float> dist(min, max);
95+
int64_t size = value.numel();
96+
for (int64_t i = 0; i < size; ++i) {
97+
data[i] = dist(engine);
98+
}
99+
w_t->Set(key, value);
100+
memory::Copy(cpu, out_t->mutable_data<float>(cpu) + index * value.numel(),
101+
cpu, value.data<float>(), value.numel() * sizeof(float));
102+
}
103+
}
104+
};
105+
106+
class LookupSparseTableOpMaker : public framework::OpProtoAndCheckerMaker {
107+
public:
108+
LookupSparseTableOpMaker(OpProto *proto, OpAttrChecker *op_checker)
109+
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
110+
AddInput("W",
111+
"(SelectedRows) The input represents embedding table, "
112+
"which is a learnable parameter.");
113+
AddInput("Ids",
114+
"(LoDTensor) Ids's type should be LoDTensor"
115+
"The ids to be looked up in W.");
116+
AddOutput("Out",
117+
"(LoDTensor) The lookup results, which have the "
118+
"same type as W.");
119+
AddAttr<int64_t>("padding_idx",
120+
"(int64, default -1) "
121+
"If the value is -1, it makes no effect to lookup. "
122+
"Otherwise the given value indicates padding the output "
123+
"with zeros whenever lookup encounters it in Ids.")
124+
.SetDefault(kNoPadding);
125+
AddAttr<float>("min",
126+
"(float, default -1.0) "
127+
"Minimum value of uniform random")
128+
.SetDefault(-1.0f);
129+
AddAttr<float>("max",
130+
"(float, default 1.0) "
131+
"Maximum value of uniform random")
132+
.SetDefault(1.0f);
133+
AddAttr<int>("seed",
134+
"(int, default 0) "
135+
"Random seed used for generating samples. "
136+
"0 means use a seed generated by the system."
137+
"Note that if seed is not 0, this operator will always "
138+
"generate the same random numbers every time.")
139+
.SetDefault(0);
140+
AddAttr<bool>("auto_grown_table",
141+
"(bool default false)"
142+
"Whether to create a new value for a nonexistent key.")
143+
.SetDefault(true);
144+
AddComment(R"DOC(
145+
Lookup Sparse Table Operator.
146+
147+
This operator is used to perform lookup on parameter W,
148+
then concatenated into a sparse tensor.
149+
150+
The type of Ids(Input) is LoDTensor, the rows of Ids contains
151+
the ids to be looked up in W;
152+
if the Id is not in the sparse table, this operator will return a
153+
random value and set the value into the table for the next lookup.
154+
155+
)DOC");
156+
}
157+
};
158+
} // namespace operators
159+
} // namespace paddle
160+
161+
namespace ops = paddle::operators;
162+
REGISTER_OPERATOR(lookup_sparse_table, ops::LookupSparseTableOp,
163+
ops::LookupSparseTableInferShape,
164+
ops::LookupSparseTableOpMaker,
165+
paddle::framework::EmptyGradOpMaker);

paddle/fluid/operators/sgd_op.cc

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,24 @@ class SGDOp : public framework::OperatorWithKernel {
4848
}
4949
};
5050

51+
class SGDOpInferVarType : public framework::VarTypeInference {
52+
public:
53+
void operator()(const framework::OpDesc& op_desc,
54+
framework::BlockDesc* block) const override {
55+
auto input_var = op_desc.Input("Param")[0];
56+
for (auto& out_var : op_desc.Output("ParamOut")) {
57+
if (block->FindRecursiveOrCreateVar(input_var).GetType() ==
58+
framework::proto::VarType::SELECTED_ROWS) {
59+
block->FindRecursiveOrCreateVar(out_var).SetType(
60+
framework::proto::VarType::SELECTED_ROWS);
61+
} else {
62+
block->FindRecursiveOrCreateVar(out_var).SetType(
63+
framework::proto::VarType::LOD_TENSOR);
64+
}
65+
}
66+
}
67+
};
68+
5169
class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
5270
public:
5371
SGDOpMaker(OpProto* proto, OpAttrChecker* op_checker)
@@ -74,5 +92,6 @@ This operator implements one step of the stochastic gradient descent algorithm.
7492
} // namespace paddle
7593

7694
namespace ops = paddle::operators;
77-
REGISTER_OP_WITHOUT_GRADIENT(sgd, ops::SGDOp, ops::SGDOpMaker);
95+
REGISTER_OPERATOR(sgd, ops::SGDOp, ops::SGDOpMaker,
96+
paddle::framework::EmptyGradOpMaker, ops::SGDOpInferVarType);
7897
REGISTER_OP_CPU_KERNEL(sgd, ops::SGDOpKernel<float>, ops::SGDOpKernel<double>);

paddle/fluid/operators/uniform_random_op.cc

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,11 +116,31 @@ uniform distribution.
116116
.SetDefault(framework::proto::VarType::FP32);
117117
}
118118
};
119+
120+
class UniformRandomOpVarTypeInference : public framework::VarTypeInference {
121+
public:
122+
void operator()(const framework::OpDesc& op_desc,
123+
framework::BlockDesc* block) const override {
124+
auto out_var_name = op_desc.Output("Out").front();
125+
if (block->FindRecursiveOrCreateVar(out_var_name).GetType() ==
126+
framework::proto::VarType::SELECTED_ROWS) {
127+
block->FindRecursiveOrCreateVar(out_var_name)
128+
.SetType(framework::proto::VarType::SELECTED_ROWS);
129+
} else {
130+
block->FindRecursiveOrCreateVar(out_var_name)
131+
.SetType(framework::proto::VarType::LOD_TENSOR);
132+
}
133+
}
134+
};
135+
119136
} // namespace operators
120137
} // namespace paddle
121138

122-
REGISTER_OP_WITHOUT_GRADIENT(uniform_random, paddle::operators::UniformRandomOp,
123-
paddle::operators::UniformRandomOpMaker);
139+
REGISTER_OPERATOR(uniform_random, paddle::operators::UniformRandomOp,
140+
paddle::operators::UniformRandomOpMaker,
141+
paddle::framework::EmptyGradOpMaker,
142+
paddle::operators::UniformRandomOpVarTypeInference);
143+
124144
REGISTER_OP_CPU_KERNEL(uniform_random,
125145
paddle::operators::CPUUniformRandomKernel<float>,
126146
paddle::operators::CPUUniformRandomKernel<double>);

python/paddle/fluid/distribute_transpiler.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -661,7 +661,7 @@ def _create_prefetch_block(self, pserver_index, pserver_program,
661661
shape=trainer_out.shape,
662662
dtype=trainer_out.dtype)
663663
prefetch_block.append_op(
664-
type=LOOKUP_TABLE_TYPE,
664+
type="lookup_sparse_table",
665665
inputs={'Ids': pserver_ids,
666666
"W": table_var},
667667
outputs={"Out": pserver_out},
@@ -685,9 +685,14 @@ def _clone_var(block, var, persistable=True):
685685

686686
# STEP: create table optimize block
687687
# create table param and grad var in pserver program
688-
param_var = _clone_var(
689-
pserver_program.global_block(),
690-
self.origin_program.global_block().vars[self.table_name])
688+
origin_param_var = self.origin_program.global_block().vars[
689+
self.table_name]
690+
param_var = pserver_program.global_block().create_var(
691+
name=origin_param_var.name,
692+
shape=origin_param_var.shape,
693+
dtype=origin_param_var.dtype,
694+
type=core.VarDesc.VarType.SELECTED_ROWS,
695+
persistable=True)
691696
grad_var = _clone_var(
692697
pserver_program.global_block(),
693698
self.origin_program.global_block().vars[framework.grad_var_name(

0 commit comments

Comments
 (0)