Skip to content

Commit e65cbd3

Browse files
authored
Merge pull request #14387 from jacquesqiao/lookup_sparse_table_add_test_mode
Lookup sparse table add test mode
2 parents 6cf8f24 + 51f3838 commit e65cbd3

File tree

5 files changed

+88
-14
lines changed

5 files changed

+88
-14
lines changed

paddle/fluid/framework/selected_rows.cc

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,26 @@ struct TensorCopyVisitor {
6363
int64_t size_;
6464
};
6565

66+
struct TensorFillVisitor {
67+
TensorFillVisitor(framework::Tensor* dst, int64_t dst_offset, int64_t size,
68+
float value)
69+
: dst_(dst), dst_offset_(dst_offset), size_(size) {}
70+
71+
template <typename T>
72+
void apply() const {
73+
// TODO(qiao): support other place
74+
platform::CPUPlace cpu;
75+
auto* tensor_data = dst_->mutable_data<T>(cpu);
76+
auto* start = tensor_data + dst_offset_;
77+
auto* end = start + size_;
78+
std::fill(start, end, static_cast<T>(0.0));
79+
}
80+
81+
framework::Tensor* dst_;
82+
int64_t dst_offset_;
83+
int64_t size_;
84+
};
85+
6686
void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows,
6787
const platform::DeviceContext& dev_ctx) {
6888
{ // the 1st field, uint32_t version
@@ -120,7 +140,17 @@ bool SelectedRows::HasKey(int64_t key) const {
120140
: true;
121141
}
122142

123-
int64_t SelectedRows::AutoGrownIndex(int64_t key, bool auto_grown) {
143+
int64_t SelectedRows::AutoGrownIndex(int64_t key, bool auto_grown,
144+
bool is_test) {
145+
if (is_test) {
146+
auto iter = id_to_index_.find(key);
147+
if (iter == id_to_index_.end()) {
148+
return -1;
149+
} else {
150+
return iter->second;
151+
}
152+
}
153+
124154
rwlock_->RDLock();
125155
auto iter = id_to_index_.find(key);
126156
if (iter == id_to_index_.end()) {
@@ -172,7 +202,7 @@ void SelectedRows::SyncIndex() {
172202
}
173203

174204
void SelectedRows::Get(const framework::Tensor& ids, framework::Tensor* value,
175-
bool auto_grown) {
205+
bool auto_grown, bool is_test) {
176206
PADDLE_ENFORCE(value->IsInitialized(),
177207
"The value tensor should be initialized.");
178208
if (ids.numel() == 0) {
@@ -183,11 +213,19 @@ void SelectedRows::Get(const framework::Tensor& ids, framework::Tensor* value,
183213
"output tensor should have the same shape with table "
184214
"except the dims[0].");
185215
for (int i = 0; i < ids.numel(); ++i) {
186-
int64_t index = AutoGrownIndex(ids.data<int64_t>()[i], auto_grown);
187-
framework::VisitDataType(
188-
framework::ToDataType(value_->type()),
189-
TensorCopyVisitor(value, i * value_width, *value_.get(),
190-
index * value_width, value_width));
216+
auto id = ids.data<int64_t>()[i];
217+
int64_t index = AutoGrownIndex(id, auto_grown, is_test);
218+
if (index < 0) {
219+
VLOG(5) << "id " << id << " not in the table, return 0";
220+
framework::VisitDataType(
221+
framework::ToDataType(value_->type()),
222+
TensorFillVisitor(value, i * value_width, value_width, 0.0));
223+
} else {
224+
framework::VisitDataType(
225+
framework::ToDataType(value_->type()),
226+
TensorCopyVisitor(value, i * value_width, *value_.get(),
227+
index * value_width, value_width));
228+
}
191229
}
192230
}
193231
}

paddle/fluid/framework/selected_rows.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ class SelectedRows {
105105
* the value
106106
*/
107107
void Get(const framework::Tensor& ids, framework::Tensor* value,
108-
bool auto_grown = false);
108+
bool auto_grown = false, bool is_test = false);
109109

110110
/*
111111
* @brief Get the index of the key from id_to_index_ map. If the key not
@@ -118,7 +118,7 @@ class SelectedRows {
118118
*
119119
* @return index of the key.
120120
*/
121-
int64_t AutoGrownIndex(int64_t key, bool auto_grown);
121+
int64_t AutoGrownIndex(int64_t key, bool auto_grown, bool is_test = false);
122122

123123
void SyncIndex();
124124

paddle/fluid/framework/selected_rows_test.cc

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,14 @@ TEST(SelectedRows, SparseTable) {
8484
data[i * embedding_width + j] = static_cast<float>(i);
8585
}
8686
}
87-
ASSERT_EQ(table.AutoGrownIndex(10, true), 0);
88-
ASSERT_EQ(table.AutoGrownIndex(8, true), 1);
89-
ASSERT_EQ(table.AutoGrownIndex(8, true), 1);
90-
ASSERT_EQ(table.AutoGrownIndex(6, true), 2);
87+
ASSERT_EQ(table.AutoGrownIndex(10, true, false), 0);
88+
ASSERT_EQ(table.AutoGrownIndex(8, true, false), 1);
89+
ASSERT_EQ(table.AutoGrownIndex(8, true, false), 1);
90+
ASSERT_EQ(table.AutoGrownIndex(6, true, false), 2);
91+
for (int64_t i = 11; i < 20; i++) {
92+
ASSERT_EQ(table.AutoGrownIndex(i, true, true), -1);
93+
ASSERT_TRUE(!table.HasKey(i));
94+
}
9195
ASSERT_TRUE(table.HasKey(10));
9296
ASSERT_TRUE(table.HasKey(8));
9397
ASSERT_TRUE(table.HasKey(6));

paddle/fluid/operators/lookup_sparse_table_op.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ class LookupSparseTableOp : public framework::OperatorBase {
4545
auto out_var = scope.FindVar(Output("Out"));
4646
auto w_var = scope.FindVar(Input("W"));
4747
auto ids_var = scope.FindVar(Input("Ids"));
48+
auto is_test = Attr<bool>("is_test");
4849

4950
PADDLE_ENFORCE(out_var->IsType<framework::LoDTensor>(),
5051
"The type of Out var should be LodTensor.");
@@ -65,7 +66,7 @@ class LookupSparseTableOp : public framework::OperatorBase {
6566
PADDLE_ENFORCE_EQ(framework::ToDataType(w_t->value().type()),
6667
framework::proto::VarType::FP32,
6768
"The sparse table only support FP32");
68-
w_t->Get(ids_t, out_t, true);
69+
w_t->Get(ids_t, out_t, true, is_test);
6970
}
7071
};
7172

@@ -91,6 +92,10 @@ class LookupSparseTableOpMaker : public framework::OpProtoAndCheckerMaker {
9192
"(bool default false)"
9293
"Whether create new value if for nonexistent key.")
9394
.SetDefault(true);
95+
AddAttr<bool>("is_test",
96+
"In test mode, lookup_sparse_table will "
97+
"return a 0 for unknown id")
98+
.SetDefault(false);
9499
AddComment(R"DOC(
95100
Lookup Sprase Tablel Operator.
96101

python/paddle/fluid/tests/unittests/test_lookup_sparse_table_op.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,33 @@ def check_with_place(self, place):
8080
assert (result_array2[3] == w_array[6]).all()
8181
assert (result_array2[4] == w_array[7]).all()
8282

83+
# create and run lookup_table operator
84+
test_lookup_table = Operator(
85+
"lookup_sparse_table",
86+
W='W',
87+
Ids='Ids',
88+
Out='Out',
89+
min=-5.0,
90+
max=10.0,
91+
seed=10,
92+
is_test=True)
93+
94+
ids = scope.var("Ids").get_tensor()
95+
unknown_id = [44, 22, 33]
96+
ids_array2 = np.array([4, 2, 3, 7, 100000] + unknown_id).astype("int64")
97+
ids.set(ids_array2, place)
98+
test_lookup_table.run(scope, place)
99+
100+
result_array2 = np.array(out_tensor)
101+
assert (result_array2[0] == w_array[5]).all()
102+
assert (result_array2[1] == w_array[1]).all()
103+
assert (result_array2[2] == w_array[2]).all()
104+
assert (result_array2[3] == w_array[6]).all()
105+
assert (result_array2[4] == w_array[7]).all()
106+
107+
for i in [5, 6, 7]:
108+
assert np.all(result_array2[i] == 0)
109+
83110
def test_w_is_selected_rows(self):
84111
places = [core.CPUPlace()]
85112
# currently only support CPU

0 commit comments

Comments
 (0)