Skip to content

Commit 0b8630b

Browse files
author
Yancey
authored
Merge pull request #9897 from Yancey1989/auto_grwon_sparse_table
Auto-grown sparse table
2 parents 6db5309 + f12b3f3 commit 0b8630b

File tree

5 files changed

+204
-8
lines changed

5 files changed

+204
-8
lines changed

paddle/fluid/framework/selected_rows.cc

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,52 @@ limitations under the License. */
1717
namespace paddle {
1818
namespace framework {
1919

20+
struct ReAllocateVisitor {
21+
ReAllocateVisitor(framework::Tensor* tensor, const framework::DDim& dims)
22+
: tensor_(tensor), dims_(dims) {}
23+
24+
template <typename T>
25+
void operator()() const {
26+
framework::Tensor cpu_tensor;
27+
platform::CPUPlace cpu;
28+
T* ptr = cpu_tensor.mutable_data<T>(dims_, cpu);
29+
const T* old_ptr =
30+
tensor_->memory_size() == 0 ? nullptr : tensor_->data<T>();
31+
if (old_ptr != nullptr) {
32+
std::copy(old_ptr, old_ptr + tensor_->numel(), ptr);
33+
}
34+
tensor_->ShareDataWith(cpu_tensor);
35+
}
36+
37+
framework::Tensor* tensor_;
38+
framework::DDim dims_;
39+
};
40+
41+
struct TensorCopyVisitor {
42+
TensorCopyVisitor(framework::Tensor* dst, int64_t dst_offset,
43+
const framework::Tensor src, int64_t src_offset,
44+
int64_t size)
45+
: dst_(dst),
46+
dst_offset_(dst_offset),
47+
src_(src),
48+
src_offset_(src_offset),
49+
size_(size) {}
50+
51+
template <typename T>
52+
void operator()() const {
53+
// TODO(Yancey1989): support other place
54+
platform::CPUPlace cpu;
55+
memory::Copy(cpu, dst_->mutable_data<T>(cpu) + dst_offset_, cpu,
56+
src_.data<T>() + src_offset_, size_ * sizeof(T));
57+
}
58+
59+
framework::Tensor* dst_;
60+
int64_t dst_offset_;
61+
framework::Tensor src_;
62+
int64_t src_offset_;
63+
int64_t size_;
64+
};
65+
2066
void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows,
2167
const platform::DeviceContext& dev_ctx) {
2268
{ // the 1st field, uint32_t version
@@ -69,5 +115,66 @@ void DeserializeFromStream(std::istream& is, SelectedRows* selected_rows,
69115
TensorFromStream(is, selected_rows->mutable_value(), dev_ctx);
70116
}
71117

118+
bool SelectedRows::HasKey(int64_t key) const {
119+
return std::find(rows_.begin(), rows_.end(), key) == rows_.end() ? false
120+
: true;
121+
}
122+
123+
std::vector<int64_t> SelectedRows::Get(std::vector<int64_t> keys,
124+
framework::Tensor* value) const {
125+
PADDLE_ENFORCE(value->IsInitialized(),
126+
"The value tensor should be initialized.");
127+
std::vector<int64_t> non_keys;
128+
int64_t value_width = value_->numel() / value_->dims()[0];
129+
PADDLE_ENFORCE_EQ(value_width, value->numel() / value->dims()[0],
130+
"output tensor should have the same shape with table "
131+
"execpt the dims[0].");
132+
133+
for (size_t i = 0; i < keys.size(); ++i) {
134+
int64_t index = Index(keys[i]);
135+
if (index == -1) {
136+
non_keys.push_back(keys[i]);
137+
} else {
138+
framework::VisitDataType(
139+
framework::ToDataType(value_->type()),
140+
TensorCopyVisitor(value, i * value_width, *value_.get(),
141+
index * value_width, value_width));
142+
}
143+
}
144+
return non_keys;
145+
}
146+
147+
bool SelectedRows::Set(int64_t key, const framework::Tensor& value) {
148+
PADDLE_ENFORCE(value.IsInitialized(), "The value should be initialized.");
149+
if (value_->IsInitialized()) {
150+
PADDLE_ENFORCE_EQ(
151+
value.type(), value_->type(),
152+
"The type of the value should be same with the original value");
153+
}
154+
PADDLE_ENFORCE_EQ(value.dims()[0], static_cast<size_t>(1),
155+
"The first dim of value should be 1.");
156+
auto index = Index(key);
157+
bool is_new_key = false;
158+
if (index == -1) {
159+
rows_.push_back(key);
160+
index = rows_.size() - 1;
161+
is_new_key = true;
162+
// whether need to resize the table
163+
if (static_cast<int64_t>(rows_.size()) > value_->dims()[0]) {
164+
auto dims = value_->dims();
165+
dims[0] = (dims[0] + 1) << 1;
166+
framework::VisitDataType(framework::ToDataType(value.type()),
167+
ReAllocateVisitor(value_.get(), dims));
168+
}
169+
}
170+
171+
framework::VisitDataType(
172+
framework::ToDataType(value.type()),
173+
TensorCopyVisitor(value_.get(),
174+
index * value_->numel() / value_->dims()[0], value,
175+
static_cast<int64_t>(0), value.numel()));
176+
return is_new_key;
177+
}
178+
72179
} // namespace framework
73180
} // namespace paddle

paddle/fluid/framework/selected_rows.h

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,33 @@ limitations under the License. */
1414

1515
#pragma once
1616

17+
#include <algorithm>
1718
#include <vector>
1819

1920
#include "paddle/fluid/framework/lod_tensor.h"
2021
#include "paddle/fluid/framework/tensor.h"
22+
#include "paddle/fluid/memory/memcpy.h"
2123

2224
namespace paddle {
2325
namespace framework {
2426

2527
class SelectedRows {
28+
/*
29+
* @brief We can use the SelectedRows structure to reproduce a sparse table.
30+
* A sparse table is a key-value structure that the key is an `int64_t`
31+
* number,
32+
* and the value is a Tensor which the first dimension is 0.
33+
* You can use the following interface to operate the sparse table, and you
34+
* can find
35+
* some detail information from the comments of each interface:
36+
*
37+
* HasKey(key), whether the sparse table has the specified key.
38+
* Set(key, value), set a key-value pair into the sparse table.
39+
* Get(keys, value*), get value by given key list and apply it to the given
40+
* value pointer
41+
* with the specified offset.
42+
*
43+
*/
2644
public:
2745
SelectedRows(const std::vector<int64_t>& rows, const int64_t& height)
2846
: rows_(rows), height_(height) {
@@ -50,12 +68,45 @@ class SelectedRows {
5068

5169
void set_rows(const Vector<int64_t>& rows) { rows_ = rows; }
5270

53-
/**
54-
* get the index of id in rows
71+
/*
72+
* @brief wheter has the specified key in the table.
73+
*
74+
* @return true if the key is exists.
75+
*/
76+
bool HasKey(int64_t key) const;
77+
78+
/*
79+
* @brief Get value by the key list, if the
80+
*
81+
* @return a list of keys which does not exists in table
82+
*/
83+
std::vector<int64_t> Get(std::vector<int64_t> keys,
84+
framework::Tensor* tensor) const;
85+
86+
/*
87+
* @brief Set a key-value pair into the table.
88+
* This function will double the value memory if it's not engouth.
89+
*
90+
* @note:
91+
* 1. The first dim of the value should be 1
92+
* 2. The value should be initialized and the data type
93+
* should be the same with the table.
94+
*
95+
* @return true if the key is a new one, otherwise false
96+
*
97+
*/
98+
bool Set(int64_t key, const Tensor& value);
99+
100+
/*
101+
* @brief Get the index of key in rows
102+
*
103+
* @return -1 if the key does not exists.
55104
*/
56-
int64_t index(int64_t id) const {
57-
auto it = std::find(rows_.begin(), rows_.end(), id);
58-
PADDLE_ENFORCE(it != rows_.end(), "id should be in rows");
105+
int64_t Index(int64_t key) const {
106+
auto it = std::find(rows_.begin(), rows_.end(), key);
107+
if (it == rows_.end()) {
108+
return static_cast<int64_t>(-1);
109+
}
59110
return static_cast<int64_t>(std::distance(rows_.begin(), it));
60111
}
61112

paddle/fluid/framework/selected_rows_test.cc

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ namespace framework {
1717

1818
class SelectedRowsTester : public ::testing::Test {
1919
public:
20-
virtual void SetUp() override {
20+
void SetUp() override {
2121
std::vector<int64_t> rows{0, 4, 7};
2222
int64_t height = 10;
2323
int64_t row_numel = 100;
@@ -59,5 +59,40 @@ TEST_F(SelectedRowsTester, SerializeAndDeseralize) {
5959
ASSERT_EQ(selected_rows_->GetCompleteDims(), dst_tensor.GetCompleteDims());
6060
}
6161

62+
TEST_F(SelectedRowsTester, Table) {
63+
platform::CPUPlace cpu;
64+
SelectedRows table;
65+
// initialize a sparse table
66+
table.mutable_value()->Resize(framework::make_ddim({1, 100}));
67+
table.mutable_value()->mutable_data<float>(cpu);
68+
table.mutable_rows()->push_back(1);
69+
70+
int64_t key = 10000;
71+
int64_t non_key = 999;
72+
framework::Tensor value;
73+
value.Resize(framework::make_ddim({1, 100}));
74+
auto ptr = value.mutable_data<float>(cpu);
75+
ptr[0] = static_cast<float>(10);
76+
77+
ASSERT_EQ(table.rows().size(), static_cast<size_t>(1));
78+
ASSERT_EQ(table.HasKey(key), false);
79+
80+
table.Set(key, value);
81+
82+
ASSERT_EQ(table.rows().size(), static_cast<size_t>(2));
83+
ASSERT_EQ(table.HasKey(key), true);
84+
// check re-allocate
85+
ASSERT_EQ(table.value().dims()[0], static_cast<int64_t>(4));
86+
87+
framework::Tensor get_value;
88+
get_value.mutable_data<float>(framework::make_ddim({2, 100}), cpu);
89+
std::vector<int64_t> keys({non_key, key});
90+
auto non_keys = table.Get(keys, &get_value);
91+
92+
ASSERT_EQ(get_value.data<float>()[100], static_cast<float>(10));
93+
ASSERT_EQ(non_keys.size(), static_cast<size_t>(1));
94+
ASSERT_EQ(non_keys[0], non_key);
95+
}
96+
6297
} // namespace framework
6398
} // namespace paddle

paddle/fluid/operators/lookup_table_op.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,8 @@ class LookupTableKernel : public framework::OpKernel<T> {
103103
memset(output + i * row_width, 0, row_width * sizeof(T));
104104
} else {
105105
PADDLE_ENFORCE_GE(ids[i], 0);
106-
auto id_index = table_t.index(ids[i]);
106+
auto id_index = table_t.Index(ids[i]);
107+
PADDLE_ENFORCE_GE(id_index, 0, "the input key should be exists.");
107108
memcpy(output + i * row_width, table + id_index * row_width,
108109
row_width * sizeof(T));
109110
}

paddle/fluid/operators/sgd_op.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,9 @@ class SGDOpKernel : public framework::OpKernel<T> {
107107
for (size_t i = 0; i < grad.rows().size(); i++) {
108108
PADDLE_ENFORCE(grad.rows()[i] < grad.height(),
109109
"Input rows index should less than height");
110-
int64_t id_index = param.index(grad.rows()[i]);
110+
int64_t id_index = param.Index(grad.rows()[i]);
111+
PADDLE_ENFORCE_GE(id_index, static_cast<int64_t>(0),
112+
"id should be in the table");
111113
for (size_t j = 0; j < grad_row_width; j++) {
112114
out_data[id_index * grad_row_width + j] -=
113115
lr[0] * grad_data[i * grad_row_width + j];

0 commit comments

Comments
 (0)