Skip to content

Commit 93606c2

Browse files
authored
Merge pull request #13689 from sneaxiy/sparse_rmsprop
Fix sparse rmsprop
2 parents 681226e + 5cedfb6 commit 93606c2

File tree

6 files changed

+496
-235
lines changed

6 files changed

+496
-235
lines changed

paddle/fluid/operators/adam_op.h

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ limitations under the License. */
1818
#include <vector>
1919
#include "paddle/fluid/framework/op_registry.h"
2020
#include "paddle/fluid/operators/detail/safe_ref.h"
21+
#include "paddle/fluid/operators/math/algorithm.h"
2122
#include "paddle/fluid/operators/math/selected_rows_functor.h"
2223
#include "paddle/fluid/platform/for_range.h"
2324

@@ -199,23 +200,9 @@ struct SparseAdamFunctor {
199200
row_numel_(row_numel),
200201
row_count_(row_count) {}
201202

202-
inline HOSTDEVICE int64_t BinarySearchInRows(int64_t row) const {
203-
int64_t beg = 0, end = row_count_ - 1;
204-
while (beg <= end) {
205-
auto mid = ((beg + end) >> 1);
206-
if (rows_[mid] == row)
207-
return mid;
208-
else if (rows_[mid] < row)
209-
beg = mid + 1;
210-
else
211-
end = mid - 1;
212-
}
213-
return -1;
214-
}
215-
216203
inline HOSTDEVICE void operator()(size_t i) const {
217-
int64_t row = i / row_numel_;
218-
auto row_idx = BinarySearchInRows(row);
204+
auto row_idx =
205+
math::BinarySearch<int64_t>(rows_, row_count_, i / row_numel_);
219206
T g = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_] : 0;
220207

221208
// The following code is the same as dense
paddle/fluid/operators/math/algorithm.h

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include <algorithm>
18+
#include <cstdint> // for int64_t
19+
#include <numeric>
20+
21+
#include "paddle/fluid/platform/hostdevice.h"
22+
23+
namespace paddle {
24+
namespace operators {
25+
namespace math {
26+
27+
// Binary search for `val` in the array `x` of `num` elements.
//
// The search halves the index range [beg, end] on every step, moving right
// when x[mid] < val and left otherwise, so `x` must be sorted in ascending
// order for the result to be meaningful.
//
// Parameters:
//   x   - pointer to the first element of the array to search.
//   num - number of elements in `x`; when num <= 0 the loop never runs.
//   val - value to look for, compared with operator== and operator<.
//
// Returns the index of an element equal to `val`, or -1 if none is found.
// When several elements equal `val`, the index returned is whichever one the
// bisection hits first, not necessarily the lowest.
template <typename T>
28+
HOSTDEVICE inline int64_t BinarySearch(const T *x, int64_t num, const T &val) {
29+
int64_t beg = 0, end = num - 1;  // inclusive bounds of the remaining range
30+
while (beg <= end) {
31+
auto mid = ((beg + end) >> 1);  // midpoint, rounded down
32+
if (x[mid] == val)
33+
return mid;
34+
else if (x[mid] < val)
35+
beg = mid + 1;  // val can only be to the right of mid
36+
else
37+
end = mid - 1;  // val can only be to the left of mid
38+
}
39+
return -1;  // range exhausted without a match
40+
}
41+
42+
} // namespace math
43+
} // namespace operators
44+
} // namespace paddle

paddle/fluid/operators/math/selected_rows_functor.cc

Lines changed: 30 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15+
#include <map>
1516
#include <set>
1617
#include <vector>
1718

19+
#include "paddle/fluid/operators/math/blas.h"
1820
#include "paddle/fluid/operators/math/selected_rows_functor.h"
1921

2022
namespace paddle {
@@ -245,40 +247,42 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
245247
const framework::SelectedRows& input,
246248
framework::SelectedRows* output) {
247249
framework::SelectedRows& out = *output;
248-
auto input_rows = input.rows();
249-
std::vector<int64_t> merge_rows;
250-
merge_rows.reserve(input_rows.size());
251-
std::unordered_map<int64_t, size_t> rows_pos_map;
252-
rows_pos_map.reserve(input_rows.size());
253-
size_t idx = 0u;
254-
for (std::vector<int64_t>::iterator iter = input_rows.begin();
255-
iter != input_rows.end(); ++iter) {
256-
if (rows_pos_map.find(*iter) == rows_pos_map.end()) {
257-
rows_pos_map[*iter] = idx++;
258-
merge_rows.emplace_back(*iter);
259-
}
250+
std::vector<int64_t> input_rows(input.rows());
251+
252+
std::map<int64_t, std::vector<int64_t>> merge_row_map;
253+
for (size_t i = 0; i < input_rows.size(); ++i) {
254+
merge_row_map[input_rows[i]].push_back(i);
260255
}
261256

262-
auto input_width = input.value().dims()[1];
263-
out.set_rows(merge_rows);
257+
std::vector<int64_t> merge_rows(merge_row_map.size());
258+
size_t idx = 0;
259+
int64_t input_width = input.value().dims()[1];
264260
out.set_height(input.height());
265-
out.mutable_value()->mutable_data<T>(
261+
262+
T* out_data = out.mutable_value()->mutable_data<T>(
266263
framework::make_ddim(
267264
{static_cast<int64_t>(merge_rows.size()), input_width}),
268265
context.GetPlace());
269-
270-
math::SetConstant<platform::CPUDeviceContext, T> constant_functor;
271-
constant_functor(context, out.mutable_value(), 0.0);
272-
273-
auto* out_data = out.mutable_value()->data<T>();
274-
auto* input_data = input.value().data<T>();
275-
276-
for (size_t i = 0; i < input_rows.size(); i++) {
277-
size_t out_i = rows_pos_map[input_rows[i]];
278-
for (int64_t j = 0; j < input_width; j++) {
279-
out_data[out_i * input_width + j] += input_data[i * input_width + j];
266+
const T* in_data = input.value().data<T>();
267+
268+
for (auto& row_pair : merge_row_map) {
269+
auto* out_ptr = out_data + idx * input_width;
270+
auto& rows = row_pair.second;
271+
merge_rows[idx] = row_pair.first;
272+
++idx;
273+
// rows.size() is always larger than 0
274+
std::memcpy(out_ptr, in_data + rows[0] * input_width,
275+
sizeof(T) * input_width);
276+
277+
for (size_t i = 1; i < rows.size(); ++i) {
278+
auto* in_ptr = in_data + rows[i] * input_width;
279+
for (int64_t j = 0; j < input_width; ++j) {
280+
out_ptr[j] += in_ptr[j];
281+
}
280282
}
281283
}
284+
285+
out.set_rows(merge_rows);
282286
}
283287
};
284288

paddle/fluid/operators/math/selected_rows_functor.h

Lines changed: 49 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414
#pragma once
1515

16+
#include <map>
1617
#include <vector>
1718

1819
#include "paddle/fluid/framework/eigen.h"
@@ -97,41 +98,39 @@ struct MergeAdd<platform::CPUDeviceContext, float> {
9798
const framework::SelectedRows& input,
9899
framework::SelectedRows* output) {
99100
framework::SelectedRows& out = *output;
100-
auto input_rows = input.rows();
101-
std::vector<int64_t> merge_rows;
102-
merge_rows.reserve(input_rows.size());
103-
std::unordered_map<int64_t, size_t> rows_pos_map;
104-
rows_pos_map.reserve(input_rows.size());
105-
size_t idx = 0u;
106-
for (std::vector<int64_t>::iterator iter = input_rows.begin();
107-
iter != input_rows.end(); ++iter) {
108-
if (rows_pos_map.find(*iter) == rows_pos_map.end()) {
109-
rows_pos_map[*iter] = idx++;
110-
merge_rows.emplace_back(*iter);
111-
}
101+
std::vector<int64_t> input_rows(input.rows());
102+
103+
std::map<int64_t, std::vector<int64_t>> merge_row_map;
104+
for (size_t i = 0; i < input_rows.size(); ++i) {
105+
merge_row_map[input_rows[i]].push_back(i);
112106
}
113107

114-
auto input_width = input.value().dims()[1];
115-
out.set_rows(merge_rows);
108+
std::vector<int64_t> merge_rows(merge_row_map.size());
109+
size_t idx = 0;
110+
int64_t input_width = input.value().dims()[1];
116111
out.set_height(input.height());
117-
out.mutable_value()->mutable_data<float>(
112+
113+
auto* out_data = out.mutable_value()->mutable_data<float>(
118114
framework::make_ddim(
119115
{static_cast<int64_t>(merge_rows.size()), input_width}),
120116
context.GetPlace());
121-
122-
math::SetConstant<platform::CPUDeviceContext, float> constant_functor;
123-
constant_functor(context, out.mutable_value(), 0.0);
124-
125-
auto* out_data = out.mutable_value()->data<float>();
126-
auto* input_data = input.value().data<float>();
117+
auto* in_data = input.value().data<float>();
127118

128119
auto blas = GetBlas<platform::CPUDeviceContext, float>(context);
129-
for (size_t i = 0; i < input_rows.size(); i++) {
130-
size_t out_i = rows_pos_map[input_rows[i]];
131-
float* y = out_data + out_i * input_width;
132-
const float* x = input_data + i * input_width;
133-
blas.AXPY(input_width, 1., x, y);
120+
for (auto& row_pair : merge_row_map) {
121+
auto* out_ptr = out_data + idx * input_width;
122+
auto& rows = row_pair.second;
123+
merge_rows[idx] = row_pair.first;
124+
++idx;
125+
// rows.size() is always larger than 0
126+
blas.VCOPY(input_width, in_data + rows[0] * input_width, out_ptr);
127+
128+
for (size_t i = 1; i < rows.size(); ++i) {
129+
blas.AXPY(input_width, 1., in_data + rows[i] * input_width, out_ptr);
130+
}
134131
}
132+
133+
out.set_rows(merge_rows);
135134
}
136135
};
137136

@@ -148,41 +147,39 @@ struct MergeAdd<platform::CPUDeviceContext, double> {
148147
const framework::SelectedRows& input,
149148
framework::SelectedRows* output) {
150149
framework::SelectedRows& out = *output;
151-
auto input_rows = input.rows();
152-
std::vector<int64_t> merge_rows;
153-
merge_rows.reserve(input_rows.size());
154-
std::unordered_map<int64_t, size_t> rows_pos_map;
155-
rows_pos_map.reserve(input_rows.size());
156-
size_t idx = 0u;
157-
for (std::vector<int64_t>::iterator iter = input_rows.begin();
158-
iter != input_rows.end(); ++iter) {
159-
if (rows_pos_map.find(*iter) == rows_pos_map.end()) {
160-
rows_pos_map[*iter] = idx++;
161-
merge_rows.emplace_back(*iter);
162-
}
150+
std::vector<int64_t> input_rows(input.rows());
151+
152+
std::map<int64_t, std::vector<int64_t>> merge_row_map;
153+
for (size_t i = 0; i < input_rows.size(); ++i) {
154+
merge_row_map[input_rows[i]].push_back(i);
163155
}
164156

165-
auto input_width = input.value().dims()[1];
166-
out.set_rows(merge_rows);
157+
std::vector<int64_t> merge_rows(merge_row_map.size());
158+
size_t idx = 0;
159+
int64_t input_width = input.value().dims()[1];
167160
out.set_height(input.height());
168-
out.mutable_value()->mutable_data<double>(
161+
162+
auto* out_data = out.mutable_value()->mutable_data<double>(
169163
framework::make_ddim(
170164
{static_cast<int64_t>(merge_rows.size()), input_width}),
171165
context.GetPlace());
172-
173-
math::SetConstant<platform::CPUDeviceContext, double> constant_functor;
174-
constant_functor(context, out.mutable_value(), 0.0);
175-
176-
auto* out_data = out.mutable_value()->data<double>();
177-
auto* input_data = input.value().data<double>();
166+
auto* in_data = input.value().data<double>();
178167

179168
auto blas = GetBlas<platform::CPUDeviceContext, double>(context);
180-
for (size_t i = 0; i < input_rows.size(); i++) {
181-
size_t out_i = rows_pos_map[input_rows[i]];
182-
double* y = out_data + out_i * input_width;
183-
const double* x = input_data + i * input_width;
184-
blas.AXPY(input_width, 1., x, y);
169+
for (auto& row_pair : merge_row_map) {
170+
auto* out_ptr = out_data + idx * input_width;
171+
auto& rows = row_pair.second;
172+
merge_rows[idx] = row_pair.first;
173+
++idx;
174+
// rows.size() is always larger than 0
175+
blas.VCOPY(input_width, in_data + rows[0] * input_width, out_ptr);
176+
177+
for (size_t i = 1; i < rows.size(); ++i) {
178+
blas.AXPY(input_width, 1., in_data + rows[i] * input_width, out_ptr);
179+
}
185180
}
181+
182+
out.set_rows(merge_rows);
186183
}
187184
};
188185

0 commit comments

Comments
 (0)