
Commit d883547

Author: kavyasrinet
Adding the FTRL optimizer. (#5785)
* Adding the FTRL optimizer
* Fixed the python test case
1 parent 32b10d3 commit d883547

File tree: 4 files changed, +316 −0 lines changed

paddle/operators/ftrl_op.cc

Lines changed: 139 additions & 0 deletions
@@ -0,0 +1,139 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/ftrl_op.h"

namespace paddle {
namespace operators {

class FTRLOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("Param"),
                   "Input(Param) of FTRL should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("SquaredAccumulator"),
                   "Input(SquaredAccumulator) of FTRL should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("LinearAccumulator"),
                   "Input(LinearAccumulator) of FTRL should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Grad"),
                   "Input(Grad) of FTRL should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
                   "Input(LearningRate) of FTRL should not be null.");

    PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
                   "Output(ParamOut) of FTRL should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("SquaredAccumOut"),
                   "Output(SquaredAccumOut) of FTRL should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("LinearAccumOut"),
                   "Output(LinearAccumOut) of FTRL should not be null.");

    auto param_dim = ctx->GetInputDim("Param");
    PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("Grad"),
                      "Two inputs of FTRL Op must have the same dimension.");

    auto lr_dim = ctx->GetInputDim("LearningRate");
    PADDLE_ENFORCE_EQ(framework::product(lr_dim), 1,
                      "Learning Rate should be a scalar.");

    ctx->SetOutputDim("ParamOut", param_dim);
    ctx->SetOutputDim("SquaredAccumOut", param_dim);
    ctx->SetOutputDim("LinearAccumOut", param_dim);
  }
};

class FTRLOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  FTRLOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("Param",
             "(Tensor, default Tensor<float>) "
             "Input parameter value that has to be updated.");
    AddInput("SquaredAccumulator",
             "(Tensor, default Tensor<float>) "
             "Accumulator that accumulates squared gradients.");
    AddInput("LinearAccumulator",
             "(Tensor, default Tensor<float>) "
             "Accumulator that accumulates linear gradients.");
    AddInput("Grad",
             "(Tensor, default Tensor<float>) "
             "Input gradient of the parameter.");
    AddInput("LearningRate",
             "(Tensor, default Tensor<float>) "
             "The learning rate should be a tensor of size 1.");

    AddOutput("ParamOut", "(Tensor) Output updated parameter value.");
    AddOutput("SquaredAccumOut",
              "(Tensor) Output accumulated squared"
              " gradients.");
    AddOutput("LinearAccumOut",
              "(Tensor) Output accumulated linear"
              " gradients.");

    AddAttr<float>("l1",
                   "(float, default 0.0) "
                   "L1 regularization strength.")
        .SetDefault(0.0f);
    AddAttr<float>("l2",
                   "(float, default 0.0) "
                   "L2 regularization strength.")
        .SetDefault(0.0f);
    AddAttr<float>("lr_power",
                   "(float, default -0.5f) "
                   "Learning Rate Power.")
        .SetDefault(-0.5f);
    AddComment(R"DOC(
FTRL (Follow The Regularized Leader) Operator.

Optimizer that implements the FTRL algorithm:

$$
new\_accum = squared\_accum + grad^2 \\
if (lr\_power == -0.5) {
  linear\_accum += grad - (\surd(new\_accum) - \surd(squared\_accum)) /
                   learning\_rate * param \\
} else {
  linear\_accum += grad -
                   (new\_accum^{-lr\_power} - squared\_accum^{-lr\_power}) /
                   learning\_rate * param \\
}

x = (l1 * sign(linear\_accum) - linear\_accum)
if (lr\_power == -0.5) {
  y = \frac{\surd(new\_accum)}{learning\_rate} + (2 * l2) \\
  pre\_shrink = \frac{x}{y} \\
  param = (abs(linear\_accum) > l1).select(pre\_shrink, 0.0) \\
} else {
  y = \frac{new\_accum^{-lr\_power}}{learning\_rate} + (2 * l2) \\
  pre\_shrink = \frac{x}{y} \\
  param = (abs(linear\_accum) > l1).select(pre\_shrink, 0.0) \\
}
squared\_accum += grad^2;
$$

The paper that proposed Follow The Regularized Leader (FTRL):
(https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf)

)DOC");
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(ftrl, ops::FTRLOp, ops::FTRLOpMaker);
REGISTER_OP_CPU_KERNEL(ftrl,
                       ops::FTRLOpKernel<paddle::platform::CPUPlace, float>);
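
A compact restatement of the DOC equations above, in our own notation rather than the operator definition's: write $n$ for the squared accumulator, $z$ for the linear accumulator, $p$ for the parameter, $g$ for the gradient, $\eta$ for the learning rate and $k$ for lr_power. One FTRL step then reads

$$
\begin{aligned}
&\tilde{n} = n + g^2 \\
&z \leftarrow z + g - \frac{\tilde{n}^{\,-k} - n^{-k}}{\eta}\,p \\
&p \leftarrow \begin{cases}
  \dfrac{l_1\,\operatorname{sign}(z) - z}{\tilde{n}^{\,-k}/\eta + 2\,l_2} & \text{if } |z| > l_1 \\
  0 & \text{otherwise}
\end{cases} \\
&n \leftarrow n + g^2
\end{aligned}
$$

The sqrt branches in the code are just the default case $k = -0.5$, for which $\tilde{n}^{\,-k} = \surd\tilde{n}$.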

paddle/operators/ftrl_op.cu

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed
under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. */

#define EIGEN_USE_GPU
#include "paddle/operators/ftrl_op.h"

namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(ftrl,
                       ops::FTRLOpKernel<paddle::platform::GPUPlace, float>);

paddle/operators/ftrl_op.h

Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;

template <typename Place, typename T>
class FTRLOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* param_out = ctx.Output<Tensor>("ParamOut");
    auto* sq_accum_out = ctx.Output<Tensor>("SquaredAccumOut");
    auto* lin_accum_out = ctx.Output<Tensor>("LinearAccumOut");

    param_out->mutable_data<T>(ctx.GetPlace());
    sq_accum_out->mutable_data<T>(ctx.GetPlace());
    lin_accum_out->mutable_data<T>(ctx.GetPlace());

    auto grad = ctx.Input<Tensor>("Grad");

    auto l1 = static_cast<T>(ctx.Attr<float>("l1"));
    auto l2 = static_cast<T>(ctx.Attr<float>("l2"));
    auto lr_power = static_cast<T>(ctx.Attr<float>("lr_power"));

    auto p = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Param"));
    auto sq_accum =
        EigenVector<T>::Flatten(*ctx.Input<Tensor>("SquaredAccumulator"));
    auto lin_accum =
        EigenVector<T>::Flatten(*ctx.Input<Tensor>("LinearAccumulator"));
    auto g = EigenVector<T>::Flatten(*grad);
    auto lr = EigenVector<T>::Flatten(*ctx.Input<Tensor>("LearningRate"));

    auto p_out = EigenVector<T>::Flatten(*param_out);
    auto s_acc_out = EigenVector<T>::Flatten(*sq_accum_out);
    auto l_acc_out = EigenVector<T>::Flatten(*lin_accum_out);
    auto place = ctx.GetEigenDevice<Place>();

    Eigen::DSizes<int, 1> grad_dsize(grad->numel());

    auto new_accum = sq_accum + g * g;
    // Special case for lr_power = -0.5
    if (lr_power == static_cast<T>(-0.5)) {
      l_acc_out.device(place) =
          lin_accum + g -
          ((new_accum.sqrt() - sq_accum.sqrt()) / lr.broadcast(grad_dsize)) * p;
    } else {
      l_acc_out.device(place) =
          lin_accum + g -
          ((new_accum.pow(-lr_power) - sq_accum.pow(-lr_power)) /
           lr.broadcast(grad_dsize)) *
              p;
    }

    auto x = (l_acc_out.constant(l1) * l_acc_out.sign() - l_acc_out);
    if (lr_power == static_cast<T>(-0.5)) {
      auto y = (new_accum.sqrt() / lr.broadcast(grad_dsize)) +
               l_acc_out.constant(static_cast<T>(2) * l2);
      auto pre_shrink = x / y;
      p_out.device(place) =
          (l_acc_out.abs() > l_acc_out.constant(l1))
              .select(pre_shrink, p.constant(static_cast<T>(0)));
    } else {
      auto y = (new_accum.pow(-lr_power) / lr.broadcast(grad_dsize)) +
               l_acc_out.constant(static_cast<T>(2) * l2);
      auto pre_shrink = x / y;
      p_out.device(place) =
          (l_acc_out.abs() > l_acc_out.constant(l1))
              .select(pre_shrink, p.constant(static_cast<T>(0)));
    }

    s_acc_out.device(place) = sq_accum + g * g;
  }
};

}  // namespace operators
}  // namespace paddle
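
As a sanity check on the Eigen expressions in FTRLOpKernel above, here is a minimal NumPy sketch of one FTRL step (illustrative only; the helper name ftrl_step is ours and not part of the Paddle API). It follows the same lr_power branching as the kernel, and when fed the same inputs as the test case below it reproduces ParamOut, SquaredAccumOut and LinearAccumOut.

import numpy as np

def ftrl_step(p, sq_accum, lin_accum, g, lr, l1=0.0, l2=0.0, lr_power=-0.5):
    """One FTRL update mirroring FTRLOpKernel (sketch, not the Paddle API)."""
    new_accum = sq_accum + g * g
    # Update the linear accumulator; the sqrt form is the lr_power == -0.5 case.
    if lr_power == -0.5:
        lin_out = lin_accum + g - (
            (np.sqrt(new_accum) - np.sqrt(sq_accum)) / lr) * p
    else:
        lin_out = lin_accum + g - (
            (np.power(new_accum, -lr_power) - np.power(sq_accum, -lr_power)) / lr) * p
    # Proximal step: L1 thresholding plus L2 shrinkage.
    x = l1 * np.sign(lin_out) - lin_out
    if lr_power == -0.5:
        y = np.sqrt(new_accum) / lr + 2 * l2
    else:
        y = np.power(new_accum, -lr_power) / lr + 2 * l2
    p_out = np.where(np.abs(lin_out) > l1, x / y, 0.0)
    sq_out = sq_accum + g * g
    return p_out, sq_out, lin_out
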
Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
import unittest
import numpy as np
from op_test import OpTest


class TestFTRLOp(OpTest):
    def setUp(self):
        self.op_type = "ftrl"
        w = np.random.random((102, 105)).astype("float32")
        g = np.random.random((102, 105)).astype("float32")
        sq_accum = np.full((102, 105), 0.1).astype("float32")
        linear_accum = np.full((102, 105), 0.1).astype("float32")
        lr = np.array([0.01]).astype("float32")
        l1 = 0.1
        l2 = 0.2
        lr_power = -0.5

        self.inputs = {
            'Param': w,
            'SquaredAccumulator': sq_accum,
            'LinearAccumulator': linear_accum,
            'Grad': g,
            'LearningRate': lr
        }
        self.attrs = {
            'l1': l1,
            'l2': l2,
            'lr_power': lr_power,
            'learning_rate': lr
        }
        new_accum = sq_accum + g * g
        if lr_power == -0.5:
            linear_out = linear_accum + g - (
                (np.sqrt(new_accum) - np.sqrt(sq_accum)) / lr) * w
        else:
            linear_out = linear_accum + g - ((np.power(
                new_accum, -lr_power) - np.power(sq_accum, -lr_power)) / lr) * w

        x = (l1 * np.sign(linear_out) - linear_out)
        if lr_power == -0.5:
            y = (np.sqrt(new_accum) / lr) + (2 * l2)
            pre_shrink = x / y
            param_out = np.where(np.abs(linear_out) > l1, pre_shrink, 0.0)
        else:
            y = (np.power(new_accum, -lr_power) / lr) + (2 * l2)
            pre_shrink = x / y
            param_out = np.where(np.abs(linear_out) > l1, pre_shrink, 0.0)

        sq_accum_out = sq_accum + g * g

        self.outputs = {
            'ParamOut': param_out,
            'SquaredAccumOut': sq_accum_out,
            'LinearAccumOut': linear_out
        }

    def test_check_output(self):
        self.check_output()


if __name__ == "__main__":
    unittest.main()
