Commit 229d4bb

sneaxiy authored and chengduo committed
cherry-pick sparse rmsprop to release/1.0.0 (#13907)
* test=release/1.0.0
* fix sparse rmsprop
* test=develop
* add check for opt op
* test=release/1.0.0
1 parent b97257b commit 229d4bb

20 files changed: +593 −175 lines changed

paddle/fluid/operators/adadelta_op.cc

Lines changed: 12 additions & 0 deletions
@@ -18,6 +18,7 @@ namespace paddle {
 namespace operators {
 
 using Tensor = framework::Tensor;
+
 class AdadeltaOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
@@ -31,6 +32,16 @@ class AdadeltaOp : public framework::OperatorWithKernel {
                    "Input(AvgSquaredGrad) of AdadeltaOp should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("AvgSquaredUpdate"),
                    "Input(AvgSquaredUpdate) of AdadeltaOp should not be null.");
+    PADDLE_ENFORCE(
+        ctx->GetInputsVarType("Param").front() ==
+            framework::proto::VarType::LOD_TENSOR,
+        "The input var's type should be LoDTensor, but the received is %s",
+        ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front());
+    PADDLE_ENFORCE(
+        ctx->GetInputsVarType("Grad").front() ==
+            framework::proto::VarType::LOD_TENSOR,
+        "The input var's type should be LoDTensor, but the received is %s",
+        ctx->Inputs("Grad").front(), ctx->GetInputsVarType("Grad").front());
 
     PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
                    "Output(ParamOut) of AdadeltaOp should not be null.");
@@ -56,6 +67,7 @@ class AdadeltaOp : public framework::OperatorWithKernel {
     ctx->SetOutputDim("AvgSquaredGradOut", param_dim);
     ctx->SetOutputDim("AvgSquaredUpdateOut", param_dim);
   }
+
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     auto input_data_type =
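The guard added to InferShape above rejects a Param or Grad input whose variable type is not LoDTensor before the op ever runs, and the same guard is repeated for the other optimizer ops below. As a rough stand-alone illustration of the pattern (a hedged sketch: VarType and EnforceLoDTensor are hypothetical names for this example, not Paddle's framework::proto::VarType or the PADDLE_ENFORCE macro):

#include <iostream>
#include <stdexcept>
#include <string>

// Hypothetical stand-ins for this sketch only.
enum class VarType { LOD_TENSOR, SELECTED_ROWS };

void EnforceLoDTensor(const std::string& name, VarType type) {
  if (type != VarType::LOD_TENSOR) {
    // Same spirit as the diff's message: name the offending input and its type.
    throw std::runtime_error(
        "The input var's type should be LoDTensor, but the received " + name +
        " is SELECTED_ROWS");
  }
}

int main() {
  EnforceLoDTensor("Param", VarType::LOD_TENSOR);     // passes silently
  try {
    EnforceLoDTensor("Grad", VarType::SELECTED_ROWS);  // rejected with a message
  } catch (const std::exception& e) {
    std::cerr << e.what() << std::endl;
  }
  return 0;
}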

paddle/fluid/operators/adadelta_op.h

Lines changed: 11 additions & 0 deletions
@@ -23,6 +23,17 @@ template <typename DeviceContext, typename T>
 class AdadeltaOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    const auto* param_var = ctx.InputVar("Param");
+    PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
+                   "The Var(%s)'s type should be LoDTensor, "
+                   "but the received is %s",
+                   ctx.Inputs("Param").front(), param_var->Type().name());
+    const auto* grad_var = ctx.InputVar("Grad");
+    PADDLE_ENFORCE(grad_var->IsType<framework::LoDTensor>(),
+                   "The Var(%s)'s type should be LoDTensor, "
+                   "but the received is %s",
+                   ctx.Inputs("Grad").front(), grad_var->Type().name());
+
     auto param_out_tensor = ctx.Output<framework::Tensor>("ParamOut");
     auto avg_squared_grad_out_tensor =
         ctx.Output<framework::Tensor>("AvgSquaredGradOut");

paddle/fluid/operators/adagrad_op.h

Lines changed: 20 additions & 13 deletions
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
+
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
 
@@ -21,42 +22,48 @@ namespace operators {
 
 template <typename DeviceContext, typename T>
 struct SparseAdagradFunctor {
-  void operator()(const DeviceContext& context,
-                  const framework::SelectedRows& grad,
-                  const framework::Tensor& learning_rate, T epsilon,
-                  framework::Tensor* moment, framework::Tensor* param);
+  void operator()(const DeviceContext &context,
+                  const framework::SelectedRows &grad,
+                  const framework::Tensor &learning_rate, T epsilon,
+                  framework::Tensor *moment, framework::Tensor *param);
 };
 
 template <typename DeviceContext, typename T>
 class AdagradOpKernel : public framework::OpKernel<T> {
  public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    auto* param_out_tensor = ctx.Output<framework::Tensor>("ParamOut");
-    auto* moment_out_tensor = ctx.Output<framework::Tensor>("MomentOut");
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    const auto *param_var = ctx.InputVar("Param");
+    PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
+                   "The Var(%s)'s type should be LoDTensor, "
+                   "but the received is %s",
+                   ctx.Inputs("Param").front(), param_var->Type().name());
+
+    auto *param_out_tensor = ctx.Output<framework::Tensor>("ParamOut");
+    auto *moment_out_tensor = ctx.Output<framework::Tensor>("MomentOut");
 
     param_out_tensor->mutable_data<T>(ctx.GetPlace());
     moment_out_tensor->mutable_data<T>(ctx.GetPlace());
 
     T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
 
-    auto* grad_var = ctx.InputVar("Grad");
+    auto *grad_var = ctx.InputVar("Grad");
     if (grad_var->IsType<framework::LoDTensor>()) {
       auto param = framework::EigenVector<T>::Flatten(
           *ctx.Input<framework::Tensor>("Param"));
       auto grad = framework::EigenVector<T>::Flatten(
           *ctx.Input<framework::Tensor>("Grad"));
       auto moment = framework::EigenVector<T>::Flatten(
           *ctx.Input<framework::Tensor>("Moment"));
-      auto* learning_rate = ctx.Input<framework::Tensor>("LearningRate");
+      auto *learning_rate = ctx.Input<framework::Tensor>("LearningRate");
 
       auto param_out = framework::EigenVector<T>::Flatten(*param_out_tensor);
       auto moment_out = framework::EigenVector<T>::Flatten(*moment_out_tensor);
-      auto* place = ctx.template device_context<DeviceContext>().eigen_device();
+      auto *place = ctx.template device_context<DeviceContext>().eigen_device();
 
       moment_out.device(*place) = moment + grad * grad;
       Eigen::DSizes<int, 1> m_dsize(moment_out_tensor->numel());
       if (platform::is_cpu_place(ctx.GetPlace())) {
-        auto* lr = learning_rate->data<T>();
+        auto *lr = learning_rate->data<T>();
         param_out.device(*place) =
             param - lr[0] * grad / (moment_out.sqrt() + epsilon);
       } else {
@@ -66,10 +73,10 @@ class AdagradOpKernel : public framework::OpKernel<T> {
             lr.broadcast(m_dsize) * grad / (moment_out.sqrt() + epsilon);
       }
     } else if (grad_var->IsType<framework::SelectedRows>()) {
-      auto* param_tensor = ctx.Input<framework::Tensor>("Param");
+      auto *param_tensor = ctx.Input<framework::Tensor>("Param");
       PADDLE_ENFORCE_EQ(param_tensor, param_out_tensor);
 
-      auto* moment_tensor = ctx.Input<framework::Tensor>("Moment");
+      auto *moment_tensor = ctx.Input<framework::Tensor>("Moment");
       PADDLE_ENFORCE_EQ(moment_tensor, moment_out_tensor);
 
       SparseAdagradFunctor<DeviceContext, T> functor;
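In the kernel above, a dense LoDTensor gradient goes through the Eigen expressions shown, while a SelectedRows gradient is handed off to SparseAdagradFunctor. As a hedged sketch of what a row-wise sparse Adagrad step amounts to (not the functor's actual implementation; SparseGrad is a hypothetical stand-in for framework::SelectedRows, i.e. row indices plus packed values), the dense update is applied only to the rows that actually carry a gradient:

#include <cmath>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a SelectedRows-style sparse gradient.
struct SparseGrad {
  std::vector<int64_t> rows;   // which parameter rows have gradients
  std::vector<float> values;   // rows.size() * row_numel packed gradient values
  int64_t row_numel;           // width of each row
};

void SparseAdagradStep(const SparseGrad& grad, float lr, float epsilon,
                       std::vector<float>* moment, std::vector<float>* param) {
  for (size_t r = 0; r < grad.rows.size(); ++r) {
    const int64_t row = grad.rows[r];
    for (int64_t j = 0; j < grad.row_numel; ++j) {
      const float g = grad.values[r * grad.row_numel + j];
      float& m = (*moment)[row * grad.row_numel + j];
      float& p = (*param)[row * grad.row_numel + j];
      m += g * g;                              // same as moment + grad * grad
      p -= lr * g / (std::sqrt(m) + epsilon);  // same form as the dense branch
    }
  }
}

Rows absent from grad.rows are left untouched, which is the point of the SelectedRows path.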

paddle/fluid/operators/adam_op.h

Lines changed: 9 additions & 16 deletions
@@ -18,6 +18,7 @@ limitations under the License. */
 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/detail/safe_ref.h"
+#include "paddle/fluid/operators/math/algorithm.h"
 #include "paddle/fluid/operators/math/selected_rows_functor.h"
 #include "paddle/fluid/platform/for_range.h"
 
@@ -199,23 +200,9 @@ struct SparseAdamFunctor {
         row_numel_(row_numel),
         row_count_(row_count) {}
 
-  inline HOSTDEVICE int64_t BinarySearchInRows(int64_t row) const {
-    int64_t beg = 0, end = row_count_ - 1;
-    while (beg <= end) {
-      auto mid = ((beg + end) >> 1);
-      if (rows_[mid] == row)
-        return mid;
-      else if (rows_[mid] < row)
-        beg = mid + 1;
-      else
-        end = mid - 1;
-    }
-    return -1;
-  }
-
   inline HOSTDEVICE void operator()(size_t i) const {
-    int64_t row = i / row_numel_;
-    auto row_idx = BinarySearchInRows(row);
+    auto row_idx =
+        math::BinarySearch<int64_t>(rows_, row_count_, i / row_numel_);
     T g = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_] : 0;
 
     // The following code is the same as dense
@@ -244,6 +231,12 @@ template <typename DeviceContext, typename T>
 class AdamOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    const auto* param_var = ctx.InputVar("Param");
+    PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
+                   "The Var(%s)'s type should be LoDTensor, "
+                   "but the received is %s",
+                   ctx.Inputs("Param").front(), param_var->Type().name());
+
     using paddle::framework::LoDTensor;
     using paddle::operators::detail::Ref;
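The hunk above drops the functor's private BinarySearchInRows helper in favor of a shared math::BinarySearch utility pulled in through the new math/algorithm.h include. Assuming it keeps the contract of the deleted code (a hedged reading of the diff, not the utility's actual source), it returns the position of a row in the sorted rows_ array, or -1 when the row carries no gradient, in which case g falls back to 0:

#include <cstdint>

// Hedged sketch mirroring the deleted BinarySearchInRows: find `row` in the
// sorted array `rows` of length `count`; return its index, or -1 if absent.
template <typename T>
int64_t BinarySearchSketch(const T* rows, int64_t count, const T& row) {
  int64_t beg = 0, end = count - 1;
  while (beg <= end) {
    const int64_t mid = (beg + end) >> 1;
    if (rows[mid] == row) return mid;
    if (rows[mid] < row) {
      beg = mid + 1;
    } else {
      end = mid - 1;
    }
  }
  return -1;  // not found: this slice of the sparse gradient is treated as zero
}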

paddle/fluid/operators/adamax_op.cc

Lines changed: 10 additions & 0 deletions
@@ -35,6 +35,16 @@ class AdamaxOp : public framework::OperatorWithKernel {
                    "Input(LearningRate) of AdamaxOp should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("Beta1Pow"),
                    "Input(Beta1Pow) of AdamaxOp should not be null.");
+    PADDLE_ENFORCE(
+        ctx->GetInputsVarType("Param").front() ==
+            framework::proto::VarType::LOD_TENSOR,
+        "The input var's type should be LoDTensor, but the received is %s",
+        ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front());
+    PADDLE_ENFORCE(
+        ctx->GetInputsVarType("Grad").front() ==
+            framework::proto::VarType::LOD_TENSOR,
+        "The input var's type should be LoDTensor, but the received is %s",
+        ctx->Inputs("Grad").front(), ctx->GetInputsVarType("Grad").front());
 
     PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
                    "Output(ParamOut) of AdamaxOp should not be null.");

paddle/fluid/operators/adamax_op.h

Lines changed: 11 additions & 0 deletions
@@ -23,6 +23,17 @@ template <typename DeviceContext, typename T>
 class AdamaxOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    const auto* param_var = ctx.InputVar("Param");
+    PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
+                   "The Var(%s)'s type should be LoDTensor, "
+                   "but the received is %s",
+                   ctx.Inputs("Param").front(), param_var->Type().name());
+    const auto* grad_var = ctx.InputVar("Grad");
+    PADDLE_ENFORCE(grad_var->IsType<framework::LoDTensor>(),
+                   "The Var(%s)'s type should be LoDTensor, "
+                   "but the received is %s",
+                   ctx.Inputs("Grad").front(), grad_var->Type().name());
+
     auto param_out_tensor = ctx.Output<framework::Tensor>("ParamOut");
     auto moment_out_tensor = ctx.Output<framework::Tensor>("MomentOut");
     auto inf_norm_out_tensor = ctx.Output<framework::Tensor>("InfNormOut");

paddle/fluid/operators/decayed_adagrad_op.cc

Lines changed: 10 additions & 0 deletions
@@ -32,6 +32,16 @@ class DecayedAdagradOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(
         ctx->HasInput("LearningRate"),
         "Input(LearningRate) of DecayedAdagradOp should not be null.");
+    PADDLE_ENFORCE(
+        ctx->GetInputsVarType("Param").front() ==
+            framework::proto::VarType::LOD_TENSOR,
+        "The input var's type should be LoDTensor, but the received is %s",
+        ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front());
+    PADDLE_ENFORCE(
+        ctx->GetInputsVarType("Grad").front() ==
+            framework::proto::VarType::LOD_TENSOR,
+        "The input var's type should be LoDTensor, but the received is %s",
+        ctx->Inputs("Grad").front(), ctx->GetInputsVarType("Grad").front());
 
     PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
                    "Output(ParamOut) of DecayedAdagradOp should not be null.");

paddle/fluid/operators/decayed_adagrad_op.h

Lines changed: 11 additions & 0 deletions
@@ -23,6 +23,17 @@ template <typename DeviceContext, typename T>
 class DecayedAdagradOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    const auto* param_var = ctx.InputVar("Param");
+    PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
+                   "The Var(%s)'s type should be LoDTensor, "
+                   "but the received is %s",
+                   ctx.Inputs("Param").front(), param_var->Type().name());
+    const auto* grad_var = ctx.InputVar("Grad");
+    PADDLE_ENFORCE(grad_var->IsType<framework::LoDTensor>(),
+                   "The Var(%s)'s type should be LoDTensor, "
+                   "but the received is %s",
+                   ctx.Inputs("Grad").front(), grad_var->Type().name());
+
     auto param_out_tensor = ctx.Output<framework::Tensor>("ParamOut");
     auto moment_out_tensor = ctx.Output<framework::Tensor>("MomentOut");

paddle/fluid/operators/ftrl_op.cc

Lines changed: 10 additions & 0 deletions
@@ -34,6 +34,16 @@ class FTRLOp : public framework::OperatorWithKernel {
                    "Input(Grad) of FTRL should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
                    "Input(LearningRate) of FTRL should not be null.");
+    PADDLE_ENFORCE(
+        ctx->GetInputsVarType("Param").front() ==
+            framework::proto::VarType::LOD_TENSOR,
+        "The input var's type should be LoDTensor, but the received is %s",
+        ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front());
+    PADDLE_ENFORCE(
+        ctx->GetInputsVarType("Grad").front() ==
+            framework::proto::VarType::LOD_TENSOR,
+        "The input var's type should be LoDTensor, but the received is %s",
+        ctx->Inputs("Grad").front(), ctx->GetInputsVarType("Grad").front());
 
     PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
                    "Output(ParamOut) of FTRL should not be null.");

paddle/fluid/operators/ftrl_op.h

Lines changed: 11 additions & 0 deletions
@@ -28,6 +28,17 @@ template <typename DeviceContext, typename T>
 class FTRLOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    const auto* param_var = ctx.InputVar("Param");
+    PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
+                   "The Var(%s)'s type should be LoDTensor, "
+                   "but the received is %s",
+                   ctx.Inputs("Param").front(), param_var->Type().name());
+    const auto* grad_var = ctx.InputVar("Grad");
+    PADDLE_ENFORCE(grad_var->IsType<framework::LoDTensor>(),
+                   "The Var(%s)'s type should be LoDTensor, "
+                   "but the received is %s",
+                   ctx.Inputs("Grad").front(), grad_var->Type().name());
+
     auto* param_out = ctx.Output<Tensor>("ParamOut");
     auto* sq_accum_out = ctx.Output<Tensor>("SquaredAccumOut");
     auto* lin_accum_out = ctx.Output<Tensor>("LinearAccumOut");
