
Commit eb7ed1b

Merge pull request #13897 from gmcather/develop
1.add position encoding 2.logloss in nn.py
2 parents e74267a + ba22624


5 files changed, +436 −0 lines changed


paddle/fluid/API.spec

Lines changed: 2 additions & 0 deletions

@@ -177,6 +177,8 @@ paddle.fluid.layers.maxout ArgSpec(args=['x', 'groups', 'name'], varargs=None, k
 paddle.fluid.layers.sequence_reverse ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.affine_channel ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None))
 paddle.fluid.layers.hash ArgSpec(args=['input', 'hash_size', 'num_hash', 'name'], varargs=None, keywords=None, defaults=(1, None))
+paddle.fluid.layers.log_loss ArgSpec(args=['input', 'label', 'epsilon', 'name'], varargs=None, keywords=None, defaults=(0.0001, None))
+paddle.fluid.layers.add_position_encoding ArgSpec(args=['input', 'alpha', 'beta', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
 paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None))
 paddle.fluid.layers.read_file ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None)
paddle/fluid/operators/add_position_encoding_op.cc

Lines changed: 97 additions & 0 deletions

/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/add_position_encoding_op.h"

namespace paddle {
namespace operators {

class AddPositionEncodingOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "X(Input) of add_position_encoding_op should not be null.");
    PADDLE_ENFORCE(
        ctx->HasOutput("Out"),
        "Out(Output) of add_position_encoding_op should not be null.");

    // The output has exactly the input's shape and LoD.
    auto x_dims = ctx->GetInputDim("X");
    ctx->SetOutputDim("Out", x_dims);
    ctx->ShareLoD("X", /*->*/ "Out");
  }
};

class AddPositionEncodingOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "X(Input) must not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Out"), "Out must not be null.");
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                   "Out@GRAD must not be null.");

    auto out_dims = ctx->GetInputDim("Out");
    if (ctx->HasOutput(framework::GradVarName("X"))) {
      ctx->SetOutputDim(framework::GradVarName("X"), out_dims);
    }
  }
};

class AddPositionEncodingOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of AddPositionEncoding operator");
    AddOutput("Out", "Output of AddPositionEncoding operator");
    AddAttr<float>("alpha", "The scale of the original embedding.")
        .SetDefault(1.0f)
        .AddCustomChecker([](const float& alpha) {
          PADDLE_ENFORCE(alpha >= 0.0f, "'alpha' must be no less than 0.0.");
        });
    AddAttr<float>("beta", "The scale of the position embedding.")
        .SetDefault(1.0f)
        .AddCustomChecker([](const float& beta) {
          PADDLE_ENFORCE(beta >= 0.0f, "'beta' must be no less than 0.0.");
        });
    AddComment(R"DOC(
Add Position Encoding Operator.

The add position encoding operator calculates the output based on the input,
alpha and beta. The size of each dimension of the parameters is checked during
infer-shape.
    )DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plt = paddle::platform;

REGISTER_OPERATOR(add_position_encoding, ops::AddPositionEncodingOp,
                  ops::AddPositionEncodingOpMaker,
                  paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(add_position_encoding_grad, ops::AddPositionEncodingOpGrad);

REGISTER_OP_CPU_KERNEL(
    add_position_encoding,
    ops::AddPositionEncodingKernel<plt::CPUDeviceContext, float>,
    ops::AddPositionEncodingKernel<plt::CPUDeviceContext, double>);

REGISTER_OP_CPU_KERNEL(
    add_position_encoding_grad,
    ops::AddPositionEncodingGradKernel<plt::CPUDeviceContext, float>,
    ops::AddPositionEncodingGradKernel<plt::CPUDeviceContext, double>);
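
Because Out = alpha * X + beta * PE is affine in X and the position encoding does not depend on X, the grad op registered above only needs to scale the upstream gradient; the kernel in the header below computes exactly dX = alpha * dOut. A one-line NumPy restatement of that rule, with illustrative names that are not part of the patch:

import numpy as np

def add_position_encoding_grad_ref(d_out, alpha=1.0):
    # The encoding PE is constant w.r.t. X, so the backward pass is
    # just a scale of the upstream gradient: dX = alpha * dOut.
    return alpha * np.asarray(d_out)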
paddle/fluid/operators/add_position_encoding_op.h

Lines changed: 105 additions & 0 deletions

/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <cmath>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"

namespace paddle {
namespace operators {

template <typename DeviceContext, typename T>
class AddPositionEncodingKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* X = context.Input<framework::LoDTensor>("X");
    auto& x_lod = X->lod();
    auto* src_ptr = X->data<T>();

    auto* Out = context.Output<framework::LoDTensor>("Out");
    auto* dst_ptr = Out->mutable_data<T>(context.GetPlace());

    float alpha = context.Attr<float>("alpha");
    float beta = context.Attr<float>("beta");

    auto x_dim = X->dims();
    int batch_size = 0;
    int max_seq_len = 0;
    int enc_size = 0;

    // A dense input is [batch, max_seq_len, enc_size]; a LoD input is a 2-D
    // LoDTensor whose level-0 LoD separates variable-length sequences.
    if (x_lod.empty()) {
      PADDLE_ENFORCE(
          x_dim.size() == 3UL,
          "The input X of Add Position Encoding should be a 3-D Tensor!");
      batch_size = x_dim[0];
      max_seq_len = x_dim[1];
      enc_size = x_dim[2];
    } else {
      PADDLE_ENFORCE(
          x_dim.size() == 2UL,
          "The input X of Add Position Encoding should be a 2-D LoDTensor!");
      PADDLE_ENFORCE(
          x_lod.size() == 1UL,
          "The Add Position Encoding Op only supports lod_level == 1!");
      batch_size = x_lod[0].size() - 1;
      max_seq_len = -1;
      enc_size = x_dim[1];
    }

    PADDLE_ENFORCE(enc_size % 2 == 0, "Only even encode sizes are supported!");

    const int half_size = enc_size / 2;
    for (int i = 0; i < batch_size; ++i) {
      const int max_length =
          x_lod.empty() ? max_seq_len : x_lod[0][i + 1] - x_lod[0][i];
      for (int j = 0; j < max_length; ++j) {
        for (int k = 0; k < half_size; ++k) {
          const double val = (half_size > 1)
                                 ? j / pow(10000.0, double(k) / (half_size - 1))
                                 : j / 10000.0;
          // The first half of each encoding row takes the sin term, the
          // second half the cos term, both scaled by beta.
          dst_ptr[k] = src_ptr[k] * alpha + sin(val) * beta;
          dst_ptr[half_size + k] =
              src_ptr[half_size + k] * alpha + cos(val) * beta;
        }
        src_ptr += enc_size;
        dst_ptr += enc_size;
      }
    }
  }
};

template <typename DeviceContext, typename T>
class AddPositionEncodingGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* dOut =
        context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
    auto dout = framework::EigenVector<T>::Flatten(*dOut);

    auto* dX =
        context.Output<framework::LoDTensor>(framework::GradVarName("X"));
    dX->mutable_data<T>(context.GetPlace());
    auto dx = framework::EigenVector<T>::Flatten(*dX);

    float alpha = context.Attr<float>("alpha");

    // The position encoding is constant w.r.t. X, so dX = alpha * dOut.
    auto* place =
        context.template device_context<DeviceContext>().eigen_device();
    dx.device(*place) = dout * static_cast<T>(alpha);
  }
};

}  // namespace operators
}  // namespace paddle
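
To make the kernel's layout concrete (the sin terms fill the first half of each encoding row, the cos terms the second half, matching the loop above), here is a minimal NumPy sketch of the dense, non-LoD forward path; the function and variable names are illustrative and not part of the patch:

import numpy as np

def add_position_encoding_ref(x, alpha=1.0, beta=1.0):
    # Mirrors the dense branch of AddPositionEncodingKernel:
    # x has shape [batch, seq_len, enc_size], with enc_size even.
    batch, seq_len, enc_size = x.shape
    assert enc_size % 2 == 0, "Only even encode sizes are supported!"
    half = enc_size // 2
    pos = np.arange(seq_len, dtype=np.float64)[:, None]   # [seq_len, 1]
    k = np.arange(half, dtype=np.float64)[None, :]        # [1, half]
    if half > 1:
        val = pos / np.power(10000.0, k / (half - 1))
    else:
        val = pos / 10000.0
    out = np.empty_like(x)
    out[:, :, :half] = alpha * x[:, :, :half] + beta * np.sin(val)
    out[:, :, half:] = alpha * x[:, :, half:] + beta * np.cos(val)
    return out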

python/paddle/fluid/layers/nn.py

Lines changed: 98 additions & 0 deletions

@@ -157,6 +157,8 @@
     'sequence_reverse',
     'affine_channel',
     'hash',
+    'log_loss',
+    'add_position_encoding',
 ]
@@ -7580,3 +7582,99 @@ def hash(input, hash_size, num_hash=1, name=None):
        attrs={'num_hash': num_hash,
               'mod_by': hash_size})
    return out


def log_loss(input, label, epsilon=1e-4, name=None):
    """
    **Negative Log Loss Layer**

    This layer accepts input predictions and target labels and returns the
    negative log loss.

    .. math::

        Out = -label * \\log{(input + \\epsilon)}
              - (1 - label) * \\log{(1 - input + \\epsilon)}

    Args:
        input (Variable|list): a 2-D tensor with shape [N x 1], where N is the
                               batch size. This input is a probability computed
                               by the previous operator.
        label (Variable|list): the ground truth, a 2-D tensor with shape
                               [N x 1], where N is the batch size.
        epsilon (float): a small constant added inside the logarithms to avoid
                         computing log(0).
        name (string): the name of the log_loss layer.

    Returns:
        Variable: A 2-D tensor with shape [N x 1], the negative log loss.

    Examples:
        .. code-block:: python

            prob = fluid.layers.sigmoid(net)
            cost = fluid.layers.log_loss(input=prob, label=label)
    """
    helper = LayerHelper('log_loss', **locals())

    if name is None:
        loss = helper.create_variable_for_type_inference(dtype=input.dtype)
    else:
        loss = helper.create_variable(
            name=name, dtype=input.dtype, persistable=False)

    helper.append_op(
        type='log_loss',
        inputs={'Predicted': [input],
                'Labels': [label]},
        outputs={'Loss': [loss]},
        attrs={'epsilon': epsilon})
    return loss
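
As a quick numerical sanity check on the docstring formula, here is a minimal NumPy version of the same element-wise loss; the names are illustrative and not part of the patch:

import numpy as np

def log_loss_ref(prob, label, epsilon=1e-4):
    # Element-wise negative log loss, as in the docstring formula.
    return -label * np.log(prob + epsilon) \
           - (1.0 - label) * np.log(1.0 - prob + epsilon)

prob = np.array([[0.9], [0.2]])   # predicted probabilities, shape [N x 1]
label = np.array([[1.0], [0.0]])  # ground-truth labels, shape [N x 1]
print(log_loss_ref(prob, label))  # both rows give a small loss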
def add_position_encoding(input, alpha, beta, name=None):
    """
    **Add Position Encoding Layer**

    This layer accepts an input 3-D Tensor of shape [N x M x P] and returns an
    output Tensor of the same shape [N x M x P] with the positional encoding
    added.

    Refer to `Attention Is All You Need <http://arxiv.org/pdf/1706.03762.pdf>`_ .

    .. math::
        PE(pos, 2i) = \\sin{(pos / 10000^{2i / P})}    \\\\
        PE(pos, 2i + 1) = \\cos{(pos / 10000^{2i / P})}  \\\\
        Out(:, pos, i) = \\alpha * input(:, pos, i) + \\beta * PE(pos, i)

    Where:
    * PE(pos, 2i): the increment for the number at even position
    * PE(pos, 2i + 1): the increment for the number at odd position

    Args:
        input (Variable): a 3-D input tensor with shape [N x M x P]
        alpha (float): multiple of the input tensor
        beta (float): multiple of the positional encoding tensor
        name (string): the name of the position encoding layer

    Returns:
        Variable: A 3-D Tensor of shape [N x M x P] with positional encoding.

    Examples:
        .. code-block:: python

            position_tensor = fluid.layers.add_position_encoding(
                input=tensor, alpha=1.0, beta=1.0)
    """
    helper = LayerHelper('add_position_encoding', **locals())
    dtype = helper.input_dtype()

    if name is None:
        out = helper.create_variable_for_type_inference(dtype=dtype)
    else:
        out = helper.create_variable(name=name, dtype=dtype, persistable=False)

    helper.append_op(
        type="add_position_encoding",
        inputs={"X": input},
        outputs={"Out": out},
        attrs={"alpha": alpha,
               "beta": beta})
    return out
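
Putting the two new layers together, a minimal sketch of how they would be wired into a fluid program; the data names and shapes are illustrative assumptions, not part of this patch:

import paddle.fluid as fluid

# Word embeddings of shape [batch, 16, 512]; shapes are illustrative.
emb = fluid.layers.data(name='emb', shape=[16, 512], dtype='float32')
enc = fluid.layers.add_position_encoding(input=emb, alpha=1.0, beta=1.0)

# Probabilities from some upstream network, scored against 0/1 labels.
prob = fluid.layers.data(name='prob', shape=[1], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='float32')
cost = fluid.layers.log_loss(input=prob, label=label, epsilon=1e-4)
avg_cost = fluid.layers.mean(cost)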
