@@ -12,8 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
See the License for the specific language governing permissions and
13
13
limitations under the License. */
14
14
#pragma once
15
+
16
+ #include < vector>
17
+
15
18
#include " paddle/fluid/framework/eigen.h"
16
19
#include " paddle/fluid/framework/selected_rows.h"
20
+ #include " paddle/fluid/operators/math/blas.h"
21
+ #include " paddle/fluid/operators/math/math_function.h"
17
22
#include " paddle/fluid/platform/device_context.h"
18
23
19
24
#define INLINE_FOR2 (sizei, sizej ) \
@@ -49,6 +54,15 @@ struct SelectedRowsAddTo {
49
54
const int64_t input2_offset, framework::SelectedRows* input2);
50
55
};
51
56
57
+ // input2 = [all input in input1] + input2
58
+ template <typename DeviceContext, typename T>
59
+ struct SelectedRowsSumTo {
60
+ void operator ()(const DeviceContext& context,
61
+ const std::vector<framework::SelectedRows*>& input1,
62
+ const std::vector<int64_t >& input2_offsets,
63
+ framework::SelectedRows* input2);
64
+ };
65
+
52
66
// input2 = input1 + input2
53
67
template <typename DeviceContext, typename T>
54
68
struct SelectedRowsAddToTensor {
@@ -70,6 +84,108 @@ struct MergeAdd {
70
84
framework::SelectedRows* output);
71
85
};
72
86
87
+ template <>
88
+ struct MergeAdd <platform::CPUDeviceContext, float > {
89
+ framework::SelectedRows operator ()(const platform::CPUDeviceContext& context,
90
+ const framework::SelectedRows& input) {
91
+ framework::SelectedRows out;
92
+ (*this )(context, input, &out);
93
+ return out;
94
+ }
95
+
96
+ void operator ()(const platform::CPUDeviceContext& context,
97
+ const framework::SelectedRows& input,
98
+ framework::SelectedRows* output) {
99
+ framework::SelectedRows& out = *output;
100
+ auto input_rows = input.rows ();
101
+ std::vector<int64_t > merge_rows;
102
+ merge_rows.reserve (input_rows.size ());
103
+ std::unordered_map<int64_t , size_t > rows_pos_map;
104
+ rows_pos_map.reserve (input_rows.size ());
105
+ size_t idx = 0u ;
106
+ for (std::vector<int64_t >::iterator iter = input_rows.begin ();
107
+ iter != input_rows.end (); ++iter) {
108
+ if (rows_pos_map.find (*iter) == rows_pos_map.end ()) {
109
+ rows_pos_map[*iter] = idx++;
110
+ merge_rows.emplace_back (*iter);
111
+ }
112
+ }
113
+
114
+ auto input_width = input.value ().dims ()[1 ];
115
+ out.set_rows (merge_rows);
116
+ out.set_height (input.height ());
117
+ out.mutable_value ()->mutable_data <float >(
118
+ framework::make_ddim (
119
+ {static_cast <int64_t >(merge_rows.size ()), input_width}),
120
+ context.GetPlace ());
121
+
122
+ math::SetConstant<platform::CPUDeviceContext, float > constant_functor;
123
+ constant_functor (context, out.mutable_value (), 0.0 );
124
+
125
+ auto * out_data = out.mutable_value ()->data <float >();
126
+ auto * input_data = input.value ().data <float >();
127
+
128
+ auto blas = GetBlas<platform::CPUDeviceContext, float >(context);
129
+ for (size_t i = 0 ; i < input_rows.size (); i++) {
130
+ size_t out_i = rows_pos_map[input_rows[i]];
131
+ float * y = out_data + out_i * input_width;
132
+ const float * x = input_data + i * input_width;
133
+ blas.AXPY (input_width, 1 ., x, y);
134
+ }
135
+ }
136
+ };
137
+
138
+ template <>
139
+ struct MergeAdd <platform::CPUDeviceContext, double > {
140
+ framework::SelectedRows operator ()(const platform::CPUDeviceContext& context,
141
+ const framework::SelectedRows& input) {
142
+ framework::SelectedRows out;
143
+ (*this )(context, input, &out);
144
+ return out;
145
+ }
146
+
147
+ void operator ()(const platform::CPUDeviceContext& context,
148
+ const framework::SelectedRows& input,
149
+ framework::SelectedRows* output) {
150
+ framework::SelectedRows& out = *output;
151
+ auto input_rows = input.rows ();
152
+ std::vector<int64_t > merge_rows;
153
+ merge_rows.reserve (input_rows.size ());
154
+ std::unordered_map<int64_t , size_t > rows_pos_map;
155
+ rows_pos_map.reserve (input_rows.size ());
156
+ size_t idx = 0u ;
157
+ for (std::vector<int64_t >::iterator iter = input_rows.begin ();
158
+ iter != input_rows.end (); ++iter) {
159
+ if (rows_pos_map.find (*iter) == rows_pos_map.end ()) {
160
+ rows_pos_map[*iter] = idx++;
161
+ merge_rows.emplace_back (*iter);
162
+ }
163
+ }
164
+
165
+ auto input_width = input.value ().dims ()[1 ];
166
+ out.set_rows (merge_rows);
167
+ out.set_height (input.height ());
168
+ out.mutable_value ()->mutable_data <double >(
169
+ framework::make_ddim (
170
+ {static_cast <int64_t >(merge_rows.size ()), input_width}),
171
+ context.GetPlace ());
172
+
173
+ math::SetConstant<platform::CPUDeviceContext, double > constant_functor;
174
+ constant_functor (context, out.mutable_value (), 0.0 );
175
+
176
+ auto * out_data = out.mutable_value ()->data <double >();
177
+ auto * input_data = input.value ().data <double >();
178
+
179
+ auto blas = GetBlas<platform::CPUDeviceContext, double >(context);
180
+ for (size_t i = 0 ; i < input_rows.size (); i++) {
181
+ size_t out_i = rows_pos_map[input_rows[i]];
182
+ double * y = out_data + out_i * input_width;
183
+ const double * x = input_data + i * input_width;
184
+ blas.AXPY (input_width, 1 ., x, y);
185
+ }
186
+ }
187
+ };
188
+
73
189
template <typename DeviceContext, typename T>
74
190
struct Add {
75
191
framework::SelectedRows operator ()(const DeviceContext& context,
0 commit comments