Skip to content

Commit b2af213

Browse files
authored
Merge pull request #14292 from sneaxiy/delete_buggy_selected_rows_functor
Delete buggy selected_rows functor
2 parents e8642c3 + 9518bc8 commit b2af213

File tree

4 files changed

+18
-55
lines changed

4 files changed

+18
-55
lines changed

paddle/fluid/operators/adagrad_op.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,8 @@ struct SparseAdagradFunctor<platform::CPUDeviceContext, T> {
119119
auto* grad_merge_data = grad_merge.mutable_value()->template data<T>();
120120

121121
// 2. m += g_m * g_m
122-
math::scatter::Mul<platform::CPUDeviceContext, T> sqare_func;
123-
auto grad_square = sqare_func(context, grad_merge, grad_merge);
122+
auto grad_square =
123+
SquareSelectedRows<platform::CPUDeviceContext, T>(context, grad_merge);
124124

125125
math::SelectedRowsAddToTensor<platform::CPUDeviceContext, T> functor;
126126
functor(context, grad_square, moment);

paddle/fluid/operators/adagrad_op.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ struct SparseAdagradFunctor<platform::CUDADeviceContext, T> {
8484
auto* grad_merge_data = grad_merge.mutable_value()->template data<T>();
8585
framework::Vector<int64_t> merge_rows(grad_merge.rows());
8686
// 2. m += g_m * g_m
87-
math::scatter::Mul<platform::CUDADeviceContext, T> sqare_func;
88-
auto grad_square = sqare_func(context, grad_merge, grad_merge);
87+
auto grad_square =
88+
SquareSelectedRows<platform::CUDADeviceContext, T>(context, grad_merge);
8989

9090
math::SelectedRowsAddToTensor<platform::CUDADeviceContext, T> functor;
9191
functor(context, grad_square, moment);

paddle/fluid/operators/adagrad_op.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,20 @@ struct SparseAdagradFunctor {
2828
framework::Tensor *moment, framework::Tensor *param);
2929
};
3030

31+
template <typename DeviceContext, typename T>
32+
framework::SelectedRows SquareSelectedRows(
33+
const DeviceContext &context, const framework::SelectedRows &input) {
34+
framework::SelectedRows out;
35+
out.set_rows(input.rows());
36+
out.set_height(input.height());
37+
out.mutable_value()->mutable_data<T>(input.value().dims(),
38+
context.GetPlace());
39+
auto e_out = framework::EigenVector<T>::Flatten(*(out.mutable_value()));
40+
auto e_in = framework::EigenVector<T>::Flatten(input.value());
41+
e_out.device(*context.eigen_device()) = e_in.square();
42+
return out;
43+
}
44+
3145
template <typename DeviceContext, typename T>
3246
class AdagradOpKernel : public framework::OpKernel<T> {
3347
public:

paddle/fluid/operators/math/selected_rows_functor.h

Lines changed: 0 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -88,57 +88,6 @@ struct MergeAdd {
8888
framework::SelectedRows* output);
8989
};
9090

91-
template <typename DeviceContext, typename T>
92-
struct Add {
93-
framework::SelectedRows operator()(const DeviceContext& context,
94-
const framework::SelectedRows& input1,
95-
const framework::SelectedRows& input2) {
96-
framework::SelectedRows out;
97-
out.set_rows(input1.rows());
98-
out.set_height(input1.height());
99-
out.mutable_value()->mutable_data<T>(input1.value().dims(),
100-
context.GetPlace());
101-
auto e_out = framework::EigenVector<T>::Flatten(*(out.mutable_value()));
102-
auto e_in1 = framework::EigenVector<T>::Flatten(input1.value());
103-
auto e_in2 = framework::EigenVector<T>::Flatten(input2.value());
104-
e_out.device(*context.eigen_device()) = e_in1 + e_in2;
105-
return out;
106-
}
107-
};
108-
109-
template <typename DeviceContext, typename T>
110-
struct Mul {
111-
// multiply two SelectedRows
112-
framework::SelectedRows operator()(const DeviceContext& context,
113-
const framework::SelectedRows& input1,
114-
const framework::SelectedRows& input2) {
115-
framework::SelectedRows out;
116-
out.set_rows(input1.rows());
117-
out.set_height(input1.height());
118-
out.mutable_value()->mutable_data<T>(input1.value().dims(),
119-
context.GetPlace());
120-
auto e_out = framework::EigenVector<T>::Flatten(*(out.mutable_value()));
121-
auto e_in1 = framework::EigenVector<T>::Flatten(input1.value());
122-
auto e_in2 = framework::EigenVector<T>::Flatten(input2.value());
123-
e_out.device(*context.eigen_device()) = e_in1 * e_in2;
124-
return out;
125-
}
126-
// multiply scalar to SelectedRows
127-
framework::SelectedRows operator()(const DeviceContext& context,
128-
const framework::SelectedRows& input1,
129-
const T input2) {
130-
framework::SelectedRows out;
131-
out.set_rows(input1.rows());
132-
out.set_height(input1.height());
133-
out.mutable_value()->mutable_data<T>(input1.value().dims(),
134-
context.GetPlace());
135-
auto e_out = framework::EigenVector<T>::Flatten(*(out.mutable_value()));
136-
auto e_in1 = framework::EigenVector<T>::Flatten(input1.value());
137-
e_out.device(*context.eigen_device()) = input2 * e_in1;
138-
return out;
139-
}
140-
};
141-
14291
enum class ScatterOps { ASSIGN, ADD, SUB, SUBBY, MUL, DIV, DIVBY };
14392

14493
// out = seleted_rows_in / tensor

0 commit comments

Comments (0)