Skip to content

Commit e37c9e6

Browse files
authored
Merge pull request #13828 from velconia/accelerate_selected_rows_functor
Accelerate SelectedRows Functors:
2 parents 2562eb9 + 3f6ec90 commit e37c9e6

File tree

4 files changed

+342
-9
lines changed

4 files changed

+342
-9
lines changed

paddle/fluid/operators/math/CMakeLists.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ add_subdirectory(detail)
33
endif(NOT WIN32)
44

55
function(math_library TARGET)
6-
# math_library is a function to create math library.
7-
# The interface is the same as cc_library.
6+
# math_library is a function to create math library.
7+
# The interface is the same as cc_library.
88
# But it handles the GPU/CPU source split and links some common libraries.
99
set(cc_srcs)
1010
set(cu_srcs)
@@ -53,7 +53,7 @@ cc_library(blas SRCS blas.cc DEPS cblas framework_proto device_context)
5353
math_library(math_function DEPS blas)
5454
math_library(maxouting)
5555
math_library(pooling)
56-
math_library(selected_rows_functor DEPS selected_rows math_function)
56+
math_library(selected_rows_functor DEPS selected_rows math_function blas)
5757
math_library(sequence2batch)
5858
math_library(sequence_padding)
5959
math_library(sequence_pooling DEPS math_function)

paddle/fluid/operators/math/selected_rows_functor.cc

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ limitations under the License. */
1515
#include <set>
1616
#include <vector>
1717

18-
#include "paddle/fluid/operators/math/math_function.h"
1918
#include "paddle/fluid/operators/math/selected_rows_functor.h"
2019

2120
namespace paddle {
@@ -150,6 +149,45 @@ template struct SelectedRowsAddTo<platform::CPUDeviceContext, double>;
150149
template struct SelectedRowsAddTo<platform::CPUDeviceContext, int>;
151150
template struct SelectedRowsAddTo<platform::CPUDeviceContext, int64_t>;
152151

152+
template <typename T>
153+
struct SelectedRowsSumTo<platform::CPUDeviceContext, T> {
154+
void operator()(const platform::CPUDeviceContext& context,
155+
const std::vector<framework::SelectedRows*>& input1,
156+
const std::vector<int64_t>& input2_offsets,
157+
framework::SelectedRows* input2) {
158+
// Ensure all selected rows have the same height
159+
size_t size = 0u;
160+
for (auto iter = input1.begin(); iter != input1.end(); ++iter) {
161+
auto& in_rows = (*iter)->rows();
162+
size += in_rows.end() - in_rows.begin();
163+
auto in1_height = (*iter)->height();
164+
PADDLE_ENFORCE_EQ(in1_height, input2->height());
165+
}
166+
// concat rows
167+
std::vector<int64_t> in2_rows;
168+
in2_rows.reserve(in2_rows.size() + size);
169+
for (auto iter = input1.begin(); iter != input1.end(); ++iter) {
170+
const framework::Vector<int64_t>& in_rows = (*iter)->rows();
171+
in2_rows.insert(in2_rows.end(), in_rows.begin(), in_rows.end());
172+
}
173+
input2->set_rows(in2_rows);
174+
175+
auto* in2_value = input2->mutable_value();
176+
auto* in2_data = in2_value->data<T>();
177+
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
178+
size_t offset = 0u;
179+
for (size_t i = 0u; i != input1.size(); ++i) {
180+
auto& in_value = input1[i]->value();
181+
const auto* in_data = in_value.data<T>();
182+
offset += input2_offsets[i];
183+
blas.VCOPY(in_value.numel(), in_data, in2_data + offset);
184+
}
185+
}
186+
};
187+
188+
template struct SelectedRowsSumTo<platform::CPUDeviceContext, float>;
189+
template struct SelectedRowsSumTo<platform::CPUDeviceContext, double>;
190+
153191
template <typename T>
154192
struct SelectedRowsAddToTensor<platform::CPUDeviceContext, T> {
155193
void operator()(const platform::CPUDeviceContext& context,
@@ -208,8 +246,18 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
208246
framework::SelectedRows* output) {
209247
framework::SelectedRows& out = *output;
210248
auto input_rows = input.rows();
211-
std::set<int64_t> row_set(input_rows.begin(), input_rows.end());
212-
std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());
249+
std::vector<int64_t> merge_rows;
250+
merge_rows.reserve(input_rows.size());
251+
std::unordered_map<int64_t, size_t> rows_pos_map;
252+
rows_pos_map.reserve(input_rows.size());
253+
size_t idx = 0u;
254+
for (std::vector<int64_t>::iterator iter = input_rows.begin();
255+
iter != input_rows.end(); ++iter) {
256+
if (rows_pos_map.find(*iter) == rows_pos_map.end()) {
257+
rows_pos_map[*iter] = idx++;
258+
merge_rows.emplace_back(*iter);
259+
}
260+
}
213261

214262
auto input_width = input.value().dims()[1];
215263
out.set_rows(merge_rows);
@@ -226,16 +274,14 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
226274
auto* input_data = input.value().data<T>();
227275

228276
for (size_t i = 0; i < input_rows.size(); i++) {
229-
size_t out_i = FindPos(merge_rows, input_rows[i]);
277+
size_t out_i = rows_pos_map[input_rows[i]];
230278
for (int64_t j = 0; j < input_width; j++) {
231279
out_data[out_i * input_width + j] += input_data[i * input_width + j];
232280
}
233281
}
234282
}
235283
};
236284

237-
template struct MergeAdd<platform::CPUDeviceContext, float>;
238-
template struct MergeAdd<platform::CPUDeviceContext, double>;
239285
template struct MergeAdd<platform::CPUDeviceContext, int>;
240286
template struct MergeAdd<platform::CPUDeviceContext, int64_t>;
241287

paddle/fluid/operators/math/selected_rows_functor.h

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414
#pragma once
15+
16+
#include <unordered_map>
#include <vector>
17+
1518
#include "paddle/fluid/framework/eigen.h"
1619
#include "paddle/fluid/framework/selected_rows.h"
20+
#include "paddle/fluid/operators/math/blas.h"
21+
#include "paddle/fluid/operators/math/math_function.h"
1722
#include "paddle/fluid/platform/device_context.h"
1823

1924
#define INLINE_FOR2(sizei, sizej) \
@@ -49,6 +54,15 @@ struct SelectedRowsAddTo {
4954
const int64_t input2_offset, framework::SelectedRows* input2);
5055
};
5156

57+
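// input2 = concat(all inputs in input1). In the CPU specialization the rows
// of every input replace input2's rows (they are not merged with the previous
// ones) and each input's values are copied, not added, into input2's value
// buffer at positions derived from input2_offsets.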
58+
template <typename DeviceContext, typename T>
59+
struct SelectedRowsSumTo {
60+
void operator()(const DeviceContext& context,
61+
const std::vector<framework::SelectedRows*>& input1,
62+
const std::vector<int64_t>& input2_offsets,
63+
framework::SelectedRows* input2);
64+
};
65+
5266
// input2 = input1 + input2
5367
template <typename DeviceContext, typename T>
5468
struct SelectedRowsAddToTensor {
@@ -70,6 +84,108 @@ struct MergeAdd {
7084
framework::SelectedRows* output);
7185
};
7286

87+
template <>
88+
struct MergeAdd<platform::CPUDeviceContext, float> {
89+
framework::SelectedRows operator()(const platform::CPUDeviceContext& context,
90+
const framework::SelectedRows& input) {
91+
framework::SelectedRows out;
92+
(*this)(context, input, &out);
93+
return out;
94+
}
95+
96+
void operator()(const platform::CPUDeviceContext& context,
97+
const framework::SelectedRows& input,
98+
framework::SelectedRows* output) {
99+
framework::SelectedRows& out = *output;
100+
auto input_rows = input.rows();
101+
std::vector<int64_t> merge_rows;
102+
merge_rows.reserve(input_rows.size());
103+
std::unordered_map<int64_t, size_t> rows_pos_map;
104+
rows_pos_map.reserve(input_rows.size());
105+
size_t idx = 0u;
106+
for (std::vector<int64_t>::iterator iter = input_rows.begin();
107+
iter != input_rows.end(); ++iter) {
108+
if (rows_pos_map.find(*iter) == rows_pos_map.end()) {
109+
rows_pos_map[*iter] = idx++;
110+
merge_rows.emplace_back(*iter);
111+
}
112+
}
113+
114+
auto input_width = input.value().dims()[1];
115+
out.set_rows(merge_rows);
116+
out.set_height(input.height());
117+
out.mutable_value()->mutable_data<float>(
118+
framework::make_ddim(
119+
{static_cast<int64_t>(merge_rows.size()), input_width}),
120+
context.GetPlace());
121+
122+
math::SetConstant<platform::CPUDeviceContext, float> constant_functor;
123+
constant_functor(context, out.mutable_value(), 0.0);
124+
125+
auto* out_data = out.mutable_value()->data<float>();
126+
auto* input_data = input.value().data<float>();
127+
128+
auto blas = GetBlas<platform::CPUDeviceContext, float>(context);
129+
for (size_t i = 0; i < input_rows.size(); i++) {
130+
size_t out_i = rows_pos_map[input_rows[i]];
131+
float* y = out_data + out_i * input_width;
132+
const float* x = input_data + i * input_width;
133+
blas.AXPY(input_width, 1., x, y);
134+
}
135+
}
136+
};
137+
138+
template <>
139+
struct MergeAdd<platform::CPUDeviceContext, double> {
140+
framework::SelectedRows operator()(const platform::CPUDeviceContext& context,
141+
const framework::SelectedRows& input) {
142+
framework::SelectedRows out;
143+
(*this)(context, input, &out);
144+
return out;
145+
}
146+
147+
void operator()(const platform::CPUDeviceContext& context,
148+
const framework::SelectedRows& input,
149+
framework::SelectedRows* output) {
150+
framework::SelectedRows& out = *output;
151+
auto input_rows = input.rows();
152+
std::vector<int64_t> merge_rows;
153+
merge_rows.reserve(input_rows.size());
154+
std::unordered_map<int64_t, size_t> rows_pos_map;
155+
rows_pos_map.reserve(input_rows.size());
156+
size_t idx = 0u;
157+
for (std::vector<int64_t>::iterator iter = input_rows.begin();
158+
iter != input_rows.end(); ++iter) {
159+
if (rows_pos_map.find(*iter) == rows_pos_map.end()) {
160+
rows_pos_map[*iter] = idx++;
161+
merge_rows.emplace_back(*iter);
162+
}
163+
}
164+
165+
auto input_width = input.value().dims()[1];
166+
out.set_rows(merge_rows);
167+
out.set_height(input.height());
168+
out.mutable_value()->mutable_data<double>(
169+
framework::make_ddim(
170+
{static_cast<int64_t>(merge_rows.size()), input_width}),
171+
context.GetPlace());
172+
173+
math::SetConstant<platform::CPUDeviceContext, double> constant_functor;
174+
constant_functor(context, out.mutable_value(), 0.0);
175+
176+
auto* out_data = out.mutable_value()->data<double>();
177+
auto* input_data = input.value().data<double>();
178+
179+
auto blas = GetBlas<platform::CPUDeviceContext, double>(context);
180+
for (size_t i = 0; i < input_rows.size(); i++) {
181+
size_t out_i = rows_pos_map[input_rows[i]];
182+
double* y = out_data + out_i * input_width;
183+
const double* x = input_data + i * input_width;
184+
blas.AXPY(input_width, 1., x, y);
185+
}
186+
}
187+
};
188+
73189
template <typename DeviceContext, typename T>
74190
struct Add {
75191
framework::SelectedRows operator()(const DeviceContext& context,

0 commit comments

Comments
 (0)