Commit 735a2db

[cherry-pick] add Adam beta1/beta2 support Variable (#21433)

* add Adam beta1/beta2 support Variable. test=develop

Parent: 2660107

11 files changed: +402 additions, -68 deletions

paddle/fluid/operators/optimizers/adam_op.cc

Lines changed: 55 additions & 21 deletions
@@ -20,27 +20,50 @@ namespace operators {
 using Tensor = framework::Tensor;
 
 void AdamOp::InferShape(framework::InferShapeContext* ctx) const {
-  PADDLE_ENFORCE(ctx->HasInput("Param"),
-                 "Input(Param) of AdamOp should not be null.");
-  PADDLE_ENFORCE(ctx->HasInput("Grad"),
-                 "Input(Grad) of AdamOp should not be null.");
-  PADDLE_ENFORCE(ctx->HasInput("Moment1"),
-                 "Input(Moment1) of AdamOp should not be null.");
-  PADDLE_ENFORCE(ctx->HasInput("Moment2"),
-                 "Input(Moment2) of AdamOp should not be null.");
-  PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
-                 "Input(LearningRate) of AdamOp should not be null.");
-  PADDLE_ENFORCE(ctx->HasInput("Beta1Pow"),
-                 "Input(Beta1Pow) of AdamOp should not be null.");
-  PADDLE_ENFORCE(ctx->HasInput("Beta2Pow"),
-                 "Input(Beta2Pow) of AdamOp should not be null.");
-
-  PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
-                 "Output(ParamOut) of AdamOp should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("Moment1Out"),
-                 "Output(Moment1Out) of AdamOp should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("Moment2Out"),
-                 "Output(Moment2Out) of AdamOp should not be null.");
+  PADDLE_ENFORCE_EQ(
+      ctx->HasInput("Param"), true,
+      platform::errors::NotFound("Input(Param) of AdamOp should not be null."));
+  PADDLE_ENFORCE_EQ(
+      ctx->HasInput("Grad"), true,
+      platform::errors::NotFound("Input(Grad) of AdamOp should not be null."));
+  PADDLE_ENFORCE_EQ(ctx->HasInput("Moment1"), true,
+                    platform::errors::NotFound(
+                        "Input(Moment1) of AdamOp should not be null."));
+  PADDLE_ENFORCE_EQ(ctx->HasInput("Moment2"), true,
+                    platform::errors::NotFound(
+                        "Input(Moment2) of AdamOp should not be null."));
+  PADDLE_ENFORCE_EQ(ctx->HasInput("LearningRate"), true,
+                    platform::errors::NotFound(
+                        "Input(LearningRate) of AdamOp should not be null."));
+  PADDLE_ENFORCE_EQ(ctx->HasInput("Beta1Pow"), true,
+                    platform::errors::NotFound(
+                        "Input(Beta1Pow) of AdamOp should not be null."));
+  PADDLE_ENFORCE_EQ(ctx->HasInput("Beta2Pow"), true,
+                    platform::errors::NotFound(
+                        "Input(Beta2Pow) of AdamOp should not be null."));
+
+  if (ctx->IsRuntime() && ctx->HasInput("Beta1Tensor")) {
+    auto beta1 = ctx->Inputs("Beta1Tensor");
+    PADDLE_ENFORCE_EQ(
+        beta1.size(), 1,
+        platform::errors::InvalidArgument("Input(Beta1Tensor) size must be 1"));
+  }
+  if (ctx->IsRuntime() && ctx->HasInput("Beta2Tensor")) {
+    auto beta2 = ctx->Inputs("Beta2Tensor");
+    PADDLE_ENFORCE_EQ(
+        beta2.size(), 1,
+        platform::errors::InvalidArgument("Input(Beta2Tensor) size must be 1"));
+  }
+
+  PADDLE_ENFORCE_EQ(ctx->HasOutput("ParamOut"), true,
+                    platform::errors::NotFound(
+                        "Output(ParamOut) of AdamOp should not be null."));
+  PADDLE_ENFORCE_EQ(ctx->HasOutput("Moment1Out"), true,
+                    platform::errors::NotFound(
+                        "Output(Moment1Out) of AdamOp should not be null."));
+  PADDLE_ENFORCE_EQ(ctx->HasOutput("Moment2Out"), true,
+                    platform::errors::NotFound(
+                        "Output(Moment2Out) of AdamOp should not be null."));
 
   auto lr_dims = ctx->GetInputDim("LearningRate");
   PADDLE_ENFORCE_NE(framework::product(lr_dims), 0,

@@ -93,6 +116,17 @@ class AdamOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("Beta1Pow", "(Tensor) Input beta1 power accumulator");
     AddInput("Beta2Pow", "(Tensor) Input beta2 power accumulator");
 
+    AddInput("Beta1Tensor",
+             "(Tensor<float32>, optional) If provided, Adam will use this "
+             "as beta1, this has a higher priority than attr(beta1), the "
+             "shape of this tensor MUST BE [1].")
+        .AsDispensable();
+    AddInput("Beta2Tensor",
+             "(Tensor<float32>, optional) If provided, Adam will use this "
+             "as beta2, this has a higher priority than attr(beta2), the "
+             "shape of this tensor MUST BE [1].")
+        .AsDispensable();
+
     AddOutput("ParamOut", "(Tensor) Output parameter");
     AddOutput("Moment1Out", "(Tensor) Output first moment");
     AddOutput("Moment2Out", "(Tensor) Output second moment");

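The two dispensable inputs declared above let beta1 and beta2 be supplied as runtime tensors instead of fixed attributes. A minimal Python-side sketch of how this could be exercised, assuming the optimizer.py change in this commit (one of the 11 files, not shown here) lets fluid.optimizer.Adam accept a float32 Variable of shape [1] for beta1/beta2 and wire it to Beta1Tensor/Beta2Tensor:

    # Hedged sketch: assumes fluid.optimizer.Adam in this commit accepts a
    # Variable for beta1/beta2 and feeds it to the adam op's Beta1Tensor input.
    import numpy as np
    import paddle.fluid as fluid

    x = fluid.layers.data(name='x', shape=[4], dtype='float32')
    y = fluid.layers.data(name='y', shape=[1], dtype='float32')
    pred = fluid.layers.fc(input=x, size=1)
    loss = fluid.layers.mean(fluid.layers.square_error_cost(input=pred, label=y))

    # Keep beta1 in a persistable [1] Variable so it can be updated between
    # iterations (e.g. a warm-up schedule) without rebuilding the program.
    beta1 = fluid.layers.create_global_var(
        shape=[1], value=0.85, dtype='float32', persistable=True, name='beta1')

    opt = fluid.optimizer.Adam(learning_rate=0.01, beta1=beta1)
    opt.minimize(loss)

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(fluid.default_startup_program())
    feed = {'x': np.random.rand(8, 4).astype('float32'),
            'y': np.random.rand(8, 1).astype('float32')}
    exe.run(fluid.default_main_program(), feed=feed, fetch_list=[loss])
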
paddle/fluid/operators/optimizers/adam_op.h

Lines changed: 21 additions & 2 deletions
@@ -29,6 +29,16 @@ namespace operators {
 
 namespace scatter = paddle::operators::math::scatter;
 
+static inline float GetAttrFromTensor(const framework::Tensor* tensor) {
+  const float* tensor_data = tensor->data<float>();
+  framework::Tensor cpu_tensor;
+  if (platform::is_gpu_place(tensor->place())) {
+    TensorCopySync(*tensor, platform::CPUPlace(), &cpu_tensor);
+    tensor_data = cpu_tensor.data<float>();
+  }
+  return tensor_data[0];
+}
+
 class AdamOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;

@@ -367,8 +377,6 @@ class AdamOpKernel : public framework::OpKernel<T> {
     int64_t min_row_size_to_use_multithread =
         ctx.Attr<int64_t>("min_row_size_to_use_multithread");
     bool lazy_mode = ctx.Attr<bool>("lazy_mode");
-    T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
-    T beta2 = static_cast<T>(ctx.Attr<float>("beta2"));
     T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
     auto& param = Ref(ctx.Input<LoDTensor>("Param"), "Must set Param");
     // auto& grad = Ref(ctx.Input<LoDTensor>("Grad"), "Must set Grad");

@@ -390,6 +398,17 @@ class AdamOpKernel : public framework::OpKernel<T> {
     auto& mom2_out =
         Ref(ctx.Output<LoDTensor>("Moment2Out"), "Must set Moment1Out");
 
+    T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
+    if (ctx.HasInput("Beta1Tensor")) {
+      auto* beta1_tensor = ctx.Input<framework::Tensor>("Beta1Tensor");
+      beta1 = static_cast<T>(GetAttrFromTensor(beta1_tensor));
+    }
+    T beta2 = static_cast<T>(ctx.Attr<float>("beta2"));
+    if (ctx.HasInput("Beta2Tensor")) {
+      auto* beta2_tensor = ctx.Input<framework::Tensor>("Beta2Tensor");
+      beta2 = static_cast<T>(GetAttrFromTensor(beta2_tensor));
+    }
+
     if (grad_var->IsType<framework::LoDTensor>()) {
       auto& grad = Ref(ctx.Input<LoDTensor>("Grad"), "Must set Grad");

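GetAttrFromTensor reads the single float out of a one-element tensor, copying it to CPU first when it lives on the GPU, and the kernel now resolves beta1/beta2 on every Compute call: the attribute is the default and a fed Beta1Tensor/Beta2Tensor takes priority. An illustrative Python mirror of that precedence rule (not Paddle code):

    # Illustrative only: mirrors the attr-vs-tensor precedence in AdamOpKernel.
    import numpy as np

    def resolve_beta(attr_value, beta_tensor=None):
        # A provided one-element tensor overrides the attribute value.
        if beta_tensor is not None:
            return float(np.asarray(beta_tensor).reshape(-1)[0])
        return float(attr_value)

    assert resolve_beta(0.9) == 0.9                                     # attr only
    assert resolve_beta(0.9, np.array([0.5], dtype=np.float32)) == 0.5  # tensor wins
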
paddle/fluid/operators/scale_op.cc

Lines changed: 18 additions & 2 deletions
@@ -34,6 +34,14 @@ class ScaleOp : public framework::OperatorWithKernel {
                    "Input(X) of ScaleOp should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
                    "Output(Out) of ScaleOp should not be null.");
+
+    if (ctx->IsRuntime() && ctx->HasInput("ScaleTensor")) {
+      auto scale = ctx->Inputs("ScaleTensor");
+      PADDLE_ENFORCE_EQ(scale.size(), 1,
+                        platform::errors::InvalidArgument(
+                            "Input(ScaleTensor) size must be 1"));
+    }
+
     ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
     ctx->ShareLoD("X", /*->*/ "Out");
   }

@@ -43,6 +51,11 @@ class ScaleOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
     AddInput("X", "(Tensor) Input tensor of scale operator.");
+    AddInput("ScaleTensor",
+             "(Tensor) If provided, use this as "
+             "scale factor, this has a higher priority than "
+             "attr(scale), the shape of this tensor MUST BE 1.")
+        .AsDispensable();
     AddOutput("Out", "(Tensor) Output tensor of scale operator.");
     AddComment(R"DOC(
 **Scale operator**

@@ -89,6 +102,9 @@ class ScaleGradMaker : public framework::SingleGradOpDescMaker {
     auto *grad_op = new framework::OpDesc();
     grad_op->SetType("scale");
     grad_op->SetInput("X", OutputGrad("Out"));
+    if (ForwardOp().Inputs().count("ScaleTensor") > 0) {
+      grad_op->SetInput("ScaleTensor", Input("ScaleTensor"));
+    }
     grad_op->SetOutput("Out", InputGrad("X"));
     grad_op->SetAttr("scale", GetAttr("scale"));
     grad_op->SetAttr("bias", 0.0f);

@@ -97,14 +113,14 @@ class ScaleGradMaker : public framework::SingleGradOpDescMaker {
   }
 };
 
-using ScaleOpInplace = framework::SingleOpInplaceInToOut;
+DECLARE_INPLACE_OP_INFERER(ScaleOpInplaceInferer, {"X", "Out"});
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 
 REGISTER_OPERATOR(scale, ops::ScaleOp, ops::ScaleOpMaker, ops::ScaleGradMaker,
-                  ops::ScaleOpVarTypeInference, ops::ScaleOpInplace);
+                  ops::ScaleOpVarTypeInference, ops::ScaleOpInplaceInferer);
 REGISTER_OP_CPU_KERNEL(
     scale, ops::ScaleKernel<paddle::platform::CPUDeviceContext, float>,
     ops::ScaleKernel<paddle::platform::CPUDeviceContext, double>,

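ScaleGradMaker now forwards ScaleTensor to the generated backward op because the gradient of out = scale * x + bias with respect to x is scale itself, so the backward pass is another scale op and must see the same runtime factor. A quick numpy check of that identity (illustrative, not Paddle code):

    # d(scale * x + bias) / dx == scale, so grad_x = grad_out * scale.
    import numpy as np

    x = np.array([[1., 2., 3.]], dtype=np.float32)
    scale = np.float32(2.0)        # the value ScaleTensor would carry at runtime
    grad_out = np.ones_like(x)     # upstream gradient
    grad_x = grad_out * scale      # what the generated backward scale op computes
    assert np.allclose(grad_x, np.full_like(x, 2.0))
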
paddle/fluid/operators/scale_op.h

Lines changed: 17 additions & 1 deletion
@@ -19,17 +19,33 @@ limitations under the License. */
 
 namespace paddle {
 namespace operators {
+
+static inline float GetAttrFromTensor(const framework::Tensor* tensor) {
+  const float* tensor_data = tensor->data<float>();
+  framework::Tensor cpu_tensor;
+  if (platform::is_gpu_place(tensor->place())) {
+    TensorCopySync(*tensor, platform::CPUPlace(), &cpu_tensor);
+    tensor_data = cpu_tensor.data<float>();
+  }
+  return tensor_data[0];
+}
+
 template <typename DeviceContext, typename T>
 class ScaleKernel : public framework::OpKernel<T> {
  public:
  virtual void Compute(const framework::ExecutionContext& ctx) const {
    auto* in_var = ctx.InputVar("X");
    auto* in = framework::GetLoDTensorOrSelectedRowsValueFromVar(*in_var);
 
-    auto scale = static_cast<T>(ctx.Attr<float>("scale"));
    auto bias = static_cast<T>(ctx.Attr<float>("bias"));
    auto bias_after_scale = ctx.Attr<bool>("bias_after_scale");
 
+    auto scale = static_cast<T>(ctx.Attr<float>("scale"));
+    if (ctx.HasInput("ScaleTensor")) {
+      auto* scale_tensor = ctx.Input<framework::Tensor>("ScaleTensor");
+      scale = GetAttrFromTensor(scale_tensor);
+    }
+
    auto* out_var = ctx.OutputVar("Out");
    if (in_var->IsType<framework::SelectedRows>() && in_var != out_var) {
      auto& in_slr = in_var->Get<framework::SelectedRows>();

python/paddle/fluid/layers/layer_function_generator.py

Lines changed: 2 additions & 0 deletions
@@ -174,6 +174,8 @@ def infer_and_check_dtype(op_proto, *args, **kwargs):
         if not isinstance(val, list) and not isinstance(val, tuple):
             val = [val]
         if len(val) == 0:
+            if len(args) == 0:
+                continue
             val = [args[0]]
             args = args[1:]
 

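The guard added to infer_and_check_dtype skips an op input when neither a keyword value nor a remaining positional argument is available, which can happen once dispensable inputs such as ScaleTensor are declared but not passed; without it, args[0] would raise IndexError. A stripped-down illustration of the guarded loop (hypothetical helper, not the actual Paddle source):

    # Stripped-down illustration (hypothetical helper, not Paddle source) of why
    # the `if len(args) == 0: continue` guard matters with optional op inputs.
    def first_dtype(input_names, *args, **kwargs):
        dtype = None
        for name in input_names:                 # e.g. ['X', 'ScaleTensor']
            val = kwargs.pop(name, [])
            if not isinstance(val, (list, tuple)):
                val = [val]
            if len(val) == 0:
                if len(args) == 0:               # the guard added in this commit
                    continue                     # optional input simply not given
                val = [args[0]]
                args = args[1:]
            for each in val:
                dtype = getattr(each, 'dtype', dtype)
        return dtype

    class FakeVar(object):
        dtype = 'float32'

    # Only 'X' is supplied; 'ScaleTensor' is dispensable and omitted.
    assert first_dtype(['X', 'ScaleTensor'], FakeVar()) == 'float32'
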
python/paddle/fluid/layers/nn.py

Lines changed: 33 additions & 9 deletions
@@ -14074,7 +14074,7 @@ def scale(x, scale=1.0, bias=0.0, bias_after_scale=True, act=None, name=None):
 
     Args:
         x(Variable): Input N-D Tensor of scale operator. Data type can be float32, float64, int8, int16, int32, int64, uint8.
-        scale(float): The scale factor of the input.
+        scale(float|Variable): The scale factor of the input, it should be a float number or a Variable with shape [1] and data type as float32.
         bias(float): The bias to be put on the input.
         bias_after_scale(bool): Apply bias addition after or before scaling. It is useful for numeric stability in some circumstances.
         act(str, optional): Activation applied to the output such as tanh, softmax, sigmoid, relu.

@@ -14099,6 +14099,27 @@ def scale(x, scale=1.0, bias=0.0, bias_after_scale=True, act=None, name=None):
 
             res = exe.run(fluid.default_main_program(), feed={'x':img}, fetch_list=[output])
             print(res) # [array([[ 3., 5., 7.], [ 9., 11., 13.]], dtype=float32)]
+
+        .. code-block:: python
+
+            # scale with parameter scale as Variable
+            import paddle.fluid as fluid
+            import numpy as np
+
+            inputs = fluid.layers.data(name="x", shape=[2, 3], dtype='float32')
+            scale = fluid.layers.data(name="scale", shape=[1], dtype='float32',
+                                      append_batch_size=False)
+            output = fluid.layers.scale(inputs, scale = scale, bias = 1.0)
+
+            exe = fluid.Executor(fluid.CPUPlace())
+            exe.run(fluid.default_startup_program())
+
+            img = np.array([[1, 2, 3], [4, 5, 6]]).astype(np.float32)
+            scale_np = np.array([2.]).astype(np.float32)
+
+            res = exe.run(fluid.default_main_program(), feed={'x':img, 'scale':scale_np}, fetch_list=[output])
+            print(res) # [array([[ 3., 5., 7.], [ 9., 11., 13.]], dtype=float32)]
+
     """
 
     helper = LayerHelper('scale', **locals())

@@ -14108,15 +14129,18 @@ def scale(x, scale=1.0, bias=0.0, bias_after_scale=True, act=None, name=None):
         out = helper.create_variable(
             name=name, dtype=x.dtype, persistable=False)
 
+    inputs = {'X': x}
+    attrs = {
+        'bias': float(bias),
+        'bias_after_scale': bias_after_scale,
+    }
+    if isinstance(scale, Variable):
+        inputs['ScaleTensor'] = scale
+    else:
+        attrs['scale'] = float(scale)
+
     helper.append_op(
-        type='scale',
-        inputs={'X': x},
-        outputs={'Out': out},
-        attrs={
-            'scale': float(scale),
-            'bias': float(bias),
-            'bias_after_scale': bias_after_scale
-        })
+        type='scale', inputs=inputs, outputs={'Out': out}, attrs=attrs)
     return helper.append_activation(out)
 
 
