@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
13
13
limitations under the License. */
14
14
#pragma once
15
15
16
+ #include < map>
16
17
#include < vector>
17
18
18
19
#include " paddle/fluid/framework/eigen.h"
@@ -97,41 +98,39 @@ struct MergeAdd<platform::CPUDeviceContext, float> {
97
98
const framework::SelectedRows& input,
98
99
framework::SelectedRows* output) {
99
100
framework::SelectedRows& out = *output;
100
- auto input_rows = input.rows ();
101
- std::vector<int64_t > merge_rows;
102
- merge_rows.reserve (input_rows.size ());
103
- std::unordered_map<int64_t , size_t > rows_pos_map;
104
- rows_pos_map.reserve (input_rows.size ());
105
- size_t idx = 0u ;
106
- for (std::vector<int64_t >::iterator iter = input_rows.begin ();
107
- iter != input_rows.end (); ++iter) {
108
- if (rows_pos_map.find (*iter) == rows_pos_map.end ()) {
109
- rows_pos_map[*iter] = idx++;
110
- merge_rows.emplace_back (*iter);
111
- }
101
+ std::vector<int64_t > input_rows (input.rows ());
102
+
103
+ std::map<int64_t , std::vector<int64_t >> merge_row_map;
104
+ for (size_t i = 0 ; i < input_rows.size (); ++i) {
105
+ merge_row_map[input_rows[i]].push_back (i);
112
106
}
113
107
114
- auto input_width = input.value ().dims ()[1 ];
115
- out.set_rows (merge_rows);
108
+ std::vector<int64_t > merge_rows (merge_row_map.size ());
109
+ size_t idx = 0 ;
110
+ int64_t input_width = input.value ().dims ()[1 ];
116
111
out.set_height (input.height ());
117
- out.mutable_value ()->mutable_data <float >(
112
+
113
+ auto * out_data = out.mutable_value ()->mutable_data <float >(
118
114
framework::make_ddim (
119
115
{static_cast <int64_t >(merge_rows.size ()), input_width}),
120
116
context.GetPlace ());
121
-
122
- math::SetConstant<platform::CPUDeviceContext, float > constant_functor;
123
- constant_functor (context, out.mutable_value (), 0.0 );
124
-
125
- auto * out_data = out.mutable_value ()->data <float >();
126
- auto * input_data = input.value ().data <float >();
117
+ auto * in_data = input.value ().data <float >();
127
118
128
119
auto blas = GetBlas<platform::CPUDeviceContext, float >(context);
129
- for (size_t i = 0 ; i < input_rows.size (); i++) {
130
- size_t out_i = rows_pos_map[input_rows[i]];
131
- float * y = out_data + out_i * input_width;
132
- const float * x = input_data + i * input_width;
133
- blas.AXPY (input_width, 1 ., x, y);
120
+ for (auto & row_pair : merge_row_map) {
121
+ auto * out_ptr = out_data + idx * input_width;
122
+ auto & rows = row_pair.second ;
123
+ merge_rows[idx] = row_pair.first ;
124
+ ++idx;
125
+ // rows.size() is always larger than 0
126
+ blas.VCOPY (input_width, in_data + rows[0 ] * input_width, out_ptr);
127
+
128
+ for (size_t i = 1 ; i < rows.size (); ++i) {
129
+ blas.AXPY (input_width, 1 ., in_data + rows[i] * input_width, out_ptr);
130
+ }
134
131
}
132
+
133
+ out.set_rows (merge_rows);
135
134
}
136
135
};
137
136
@@ -148,41 +147,39 @@ struct MergeAdd<platform::CPUDeviceContext, double> {
148
147
const framework::SelectedRows& input,
149
148
framework::SelectedRows* output) {
150
149
framework::SelectedRows& out = *output;
151
- auto input_rows = input.rows ();
152
- std::vector<int64_t > merge_rows;
153
- merge_rows.reserve (input_rows.size ());
154
- std::unordered_map<int64_t , size_t > rows_pos_map;
155
- rows_pos_map.reserve (input_rows.size ());
156
- size_t idx = 0u ;
157
- for (std::vector<int64_t >::iterator iter = input_rows.begin ();
158
- iter != input_rows.end (); ++iter) {
159
- if (rows_pos_map.find (*iter) == rows_pos_map.end ()) {
160
- rows_pos_map[*iter] = idx++;
161
- merge_rows.emplace_back (*iter);
162
- }
150
+ std::vector<int64_t > input_rows (input.rows ());
151
+
152
+ std::map<int64_t , std::vector<int64_t >> merge_row_map;
153
+ for (size_t i = 0 ; i < input_rows.size (); ++i) {
154
+ merge_row_map[input_rows[i]].push_back (i);
163
155
}
164
156
165
- auto input_width = input.value ().dims ()[1 ];
166
- out.set_rows (merge_rows);
157
+ std::vector<int64_t > merge_rows (merge_row_map.size ());
158
+ size_t idx = 0 ;
159
+ int64_t input_width = input.value ().dims ()[1 ];
167
160
out.set_height (input.height ());
168
- out.mutable_value ()->mutable_data <double >(
161
+
162
+ auto * out_data = out.mutable_value ()->mutable_data <double >(
169
163
framework::make_ddim (
170
164
{static_cast <int64_t >(merge_rows.size ()), input_width}),
171
165
context.GetPlace ());
172
-
173
- math::SetConstant<platform::CPUDeviceContext, double > constant_functor;
174
- constant_functor (context, out.mutable_value (), 0.0 );
175
-
176
- auto * out_data = out.mutable_value ()->data <double >();
177
- auto * input_data = input.value ().data <double >();
166
+ auto * in_data = input.value ().data <double >();
178
167
179
168
auto blas = GetBlas<platform::CPUDeviceContext, double >(context);
180
- for (size_t i = 0 ; i < input_rows.size (); i++) {
181
- size_t out_i = rows_pos_map[input_rows[i]];
182
- double * y = out_data + out_i * input_width;
183
- const double * x = input_data + i * input_width;
184
- blas.AXPY (input_width, 1 ., x, y);
169
+ for (auto & row_pair : merge_row_map) {
170
+ auto * out_ptr = out_data + idx * input_width;
171
+ auto & rows = row_pair.second ;
172
+ merge_rows[idx] = row_pair.first ;
173
+ ++idx;
174
+ // rows.size() is always larger than 0
175
+ blas.VCOPY (input_width, in_data + rows[0 ] * input_width, out_ptr);
176
+
177
+ for (size_t i = 1 ; i < rows.size (); ++i) {
178
+ blas.AXPY (input_width, 1 ., in_data + rows[i] * input_width, out_ptr);
179
+ }
185
180
}
181
+
182
+ out.set_rows (merge_rows);
186
183
}
187
184
};
188
185
0 commit comments