
Commit 39d88eb

Merge branch 'fix-optimizer-accumulator' of ssh://github.com/jacquesqiao/Paddle into distribute-transpiler-handle-adam-accumulator
2 parents: 5c12c5e + 3748aa4

17 files changed: +1537 −753 lines

paddle/fluid/operators/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -259,6 +259,7 @@ op_library(max_sequence_len_op DEPS lod_rank_table)
 op_library(sequence_conv_op DEPS context_project)
 op_library(sequence_pool_op DEPS sequence_pooling)
 op_library(lstm_op DEPS sequence2batch lstm_compute)
+op_library(hierarchical_sigmoid_op DEPS matrix_bit_code)
 op_library(lstmp_op DEPS sequence2batch lstm_compute)
 op_library(gru_op DEPS sequence2batch gru_compute)
 op_library(recurrent_op DEPS executor)
paddle/fluid/operators/hierarchical_sigmoid_op.cc (new file)

Lines changed: 167 additions & 0 deletions

@@ -0,0 +1,167 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/hierarchical_sigmoid_op.h"
#include <vector>

namespace paddle {
namespace operators {

/**
 * Organize the classes into a binary tree. At each node, a sigmoid function
 * is used to calculate the probability of belonging to the right branch.
 * This idea is from "F. Morin, Y. Bengio (AISTATS 05):
 * Hierarchical Probabilistic Neural Network Language Model."
 *
 * Here we use a simple way of making the binary tree.
 * Assuming the number of classes is C = 6, the classes are organized as a
 * binary tree in the following way:
 *
 * @code{.py}
 * *-*-*- 2
 * | | |- 3
 * | |
 * | |-*- 4
 * |   |- 5
 * |
 * |-*- 0
 *   |- 1
 * @endcode
 *
 * where * indicates an internal node, and each leaf node represents a class.
 * - Nodes 0 ... C-2 are internal nodes.
 * - Nodes C-1 ... 2C-2 are leaf nodes.
 * - Class c is represented by leaf node \f$c+C-1\f$.
 *
 * We assign an id to each node:
 * - the id of the root is 0.
 * - the left child of node i is 2*i+1.
 * - the right child of node i is 2*i+2.
 *
 * It's easy to see that:
 * - the parent of node i is \f$\left\lfloor(i-1)/2\right\rfloor\f$.
 * - the j-th level ancestor of node i is
 *   \f$\left\lfloor(i+1)/2^{j+1}\right\rfloor - 1\f$.
 * - node i is a left child of its parent if \f$(i-1)\%2==0\f$.
 */

class HierarchicalSigmoidOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("PreOut"),
                   "Output(PreOut) should not be null.");
    const int64_t batch_size = ctx->GetInputDim("X")[0];
    std::vector<int64_t> output_shape({batch_size, 1});
    ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
        ctx.GetPlace());
  }
};

template <typename AttrType>
class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "(Tensor, required) The input tensor with shape [N, D], "
             "where N is the size of the mini-batch and D is the feature "
             "size.");
    AddInput("W",
             "(Tensor, required) The parameters of the hierarchical "
             "sigmoid operator; each is a 2-D tensor with shape "
             "[num_classes - 1, D].");
    AddInput("Label",
             "(Tensor, required) The labels of the training data. It's a "
             "tensor with shape [N, 1].");
    AddInput("Bias",
             "(Tensor, optional) The bias is a tensor with shape "
             "[1, num_classes - 1].");
    AddOutput("Out",
              "(Tensor, required) The output of the hierarchical sigmoid "
              "operator. The shape is [N, 1].");
    AddOutput("PreOut",
              "(Tensor, required) An intermediate 2-D tensor with shape "
              "[batch_size, code_length], where code_length represents the "
              "maximum path length from root to leaf nodes.")
        .AsIntermediate();
    AddAttr<AttrType>("num_classes", "(int, required) The number of classes.")
        .SetDefault(2);
    AddComment(R"DOC(
The hierarchical sigmoid operator organizes the classes into a binary tree.
At each node, a sigmoid function is used to calculate the probability of
belonging to the right branch. This idea is from
"F. Morin, Y. Bengio (AISTATS 05):
Hierarchical Probabilistic Neural Network Language Model."
)DOC");
  }
};

class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("PreOut"),
                   "Input(PreOut) should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("W")),
                   "Output(W@Grad) should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")));
    if (ctx->HasOutput(framework::GradVarName("Bias"))) {
      ctx->SetOutputDim(framework::GradVarName("Bias"),
                        ctx->GetInputDim("Bias"));
    }
    ctx->SetOutputDim(framework::GradVarName("W"), ctx->GetInputDim("W"));
    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
        ctx.GetPlace());
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(hierarchical_sigmoid, ops::HierarchicalSigmoidOp,
                  ops::HierarchicalSigmoidOpMaker<int>,
                  paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(hierarchical_sigmoid_grad, ops::HierarchicalSigmoidGradOp);
REGISTER_OP_CPU_KERNEL(
    hierarchical_sigmoid,
    ops::HierarchicalSigmoidOpKernel<paddle::platform::CPUDeviceContext, float>,
    ops::HierarchicalSigmoidOpKernel<paddle::platform::CPUDeviceContext,
                                     double>);
REGISTER_OP_CPU_KERNEL(
    hierarchical_sigmoid_grad,
    ops::HierarchicalSigmoidGradOpKernel<paddle::platform::CPUDeviceContext,
                                         float>,
    ops::HierarchicalSigmoidGradOpKernel<paddle::platform::CPUDeviceContext,
                                         double>);
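To make the node-id arithmetic in the comment above concrete, here is a small standalone C++ sketch (not part of the commit; the helper names are mine) that evaluates the formulas for the C = 6 example: class c is stored at leaf c + C - 1, and the parent/ancestor formulas walk back toward the root.

// Minimal sketch of the complete-binary-tree layout described above:
// the root has id 0, the children of node i are 2*i+1 and 2*i+2, and
// class c is represented by leaf node c + C - 1.
#include <iostream>

int Parent(int i) { return (i - 1) / 2; }
bool IsLeftChild(int i) { return (i - 1) % 2 == 0; }
// j-th level ancestor: floor((i + 1) / 2^(j + 1)) - 1; j = 0 gives the parent.
int Ancestor(int i, int j) { return ((i + 1) >> (j + 1)) - 1; }

int main() {
  const int num_classes = 6;  // the C = 6 example from the comment
  for (int c = 0; c < num_classes; ++c) {
    int node = c + num_classes - 1;  // leaf node representing class c
    std::cout << "class " << c << " -> leaf " << node
              << ", parent " << Parent(node)
              << ", grandparent " << Ancestor(node, 1)
              << (IsLeftChild(node) ? " (left child)\n" : " (right child)\n");
  }
  return 0;
}

For class 0 this prints leaf 5 with parent 2 and grandparent 0 (a depth-2 leaf, like classes 0 and 1 in the diagram), while classes 2 through 5 land on depth-3 leaves.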
paddle/fluid/operators/hierarchical_sigmoid_op.h (new file)

Lines changed: 135 additions & 0 deletions

@@ -0,0 +1,135 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <iostream>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/clip_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/matrix_bit_code.h"
#include "paddle/fluid/platform/transform.h"

namespace paddle {
namespace operators {

template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
using platform::Transform;

template <typename DeviceContext, typename T>
class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in = ctx.Input<framework::Tensor>("X");
    auto* w = ctx.Input<framework::Tensor>("W");
    auto* label = ctx.Input<framework::Tensor>("Label");
    auto* bias = ctx.Input<framework::Tensor>("Bias");
    auto* out = ctx.Output<framework::Tensor>("Out");
    auto* pre_out = ctx.Output<framework::Tensor>("PreOut");
    size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
    int64_t code_length = math::FindLastSet(num_classes - 1);
    int64_t batch_size = in->dims()[0];
    framework::Tensor sum;
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    auto* pre_out_data = pre_out->mutable_data<T>(
        framework::make_ddim({batch_size, code_length}), ctx.GetPlace());
    auto pre_out_mat = EigenMatrix<T>::From(*pre_out);
    // Not all class (leaf) nodes' path lengths equal code_length, so
    // initializing PreOut with zeros keeps the out-of-path entries from
    // contributing to the loss.
    math::SetConstant<DeviceContext, T> zero;
    zero(dev_ctx, pre_out, static_cast<T>(0.0));
    auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
    math::RowwiseSum<DeviceContext, T> row_sum;
    math::MatrixBitCodeFunctor<T> bit_code(num_classes, label->data<int64_t>());

    std::vector<int64_t> sum_dims({batch_size, 1UL});
    sum.mutable_data<T>(framework::make_ddim(sum_dims), ctx.GetPlace());
    auto sum_mat = EigenMatrix<T>::From(sum);
    out->mutable_data<T>(ctx.GetPlace());
    auto out_mat = framework::EigenVector<T>::Flatten(*out);
    if (bias) {
      bit_code.Add(pre_out, *bias);
    }
    bit_code.Mul(pre_out, *w, *in);
    // clip to [-40, 40]
    Transform<DeviceContext> trans;
    trans(ctx.template device_context<DeviceContext>(), pre_out_data,
          pre_out_data + pre_out->numel(), pre_out_data,
          ClipFunctor<T>(static_cast<T>(-40.0), static_cast<T>(40.0)));
    bit_code.Sum(*pre_out, out, static_cast<T>(-1));
    // use softrelu to calculate cross entropy
    pre_out_mat.device(place) = (static_cast<T>(1.0) + pre_out_mat.exp()).log();
    row_sum(dev_ctx, *pre_out, &sum);
    // TODO(guosheng): Subtract the out-of-path loss, since not all
    // class (leaf) nodes' path lengths equal code_length. But it won't break
    // the gradient check, since both include the out-of-path loss and it
    // cancels out.
    out_mat.device(place) = sum_mat + out_mat;
  }
};

template <typename DeviceContext, typename T>
class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in = ctx.Input<framework::Tensor>("X");
    auto* w = ctx.Input<framework::Tensor>("W");
    auto* in_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
    auto* w_grad = ctx.Output<framework::Tensor>(framework::GradVarName("W"));
    auto* bias_grad =
        ctx.Output<framework::Tensor>(framework::GradVarName("Bias"));
    auto* label = ctx.Input<framework::Tensor>("Label");
    auto* pre_out = ctx.Input<framework::Tensor>("PreOut");
    auto* out_grad =
        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    framework::Tensor pre_out_grad;

    pre_out_grad.mutable_data<T>(pre_out->dims(), ctx.GetPlace());
    in_grad->mutable_data<T>(ctx.GetPlace());
    w_grad->mutable_data<T>(ctx.GetPlace());
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    math::SetConstant<DeviceContext, T> zero;
    zero(dev_ctx, in_grad, static_cast<T>(0.0));
    zero(dev_ctx, w_grad, static_cast<T>(0.0));

    size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
    math::MatrixBitCodeFunctor<T> bit_code(num_classes, label->data<int64_t>());

    auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
    auto pre_out_mat = EigenMatrix<T>::From(*pre_out);
    auto pre_out_grad_mat = EigenMatrix<T>::From(pre_out_grad);
    auto out_grad_mat = EigenMatrix<T>::From(*out_grad);
    Eigen::array<int, 2> bcast({{1, static_cast<int>(pre_out_grad.dims()[1])}});

    // softrelu derivative
    pre_out_grad_mat.device(place) =
        static_cast<T>(1.0) - static_cast<T>(1.0) / pre_out_mat.exp();
    bit_code.Sub(&pre_out_grad);  // the gradient of clip(w * x + b)
    pre_out_grad_mat.device(place) =
        pre_out_grad_mat * out_grad_mat.broadcast(bcast);
    // TODO(guosheng): multiply pre_out_grad by the subgradient of clipping to
    // be consistent with the clipping in the forward pass.
    if (bias_grad) {
      bias_grad->mutable_data<T>(ctx.GetPlace());
      zero(dev_ctx, bias_grad, static_cast<T>(0.0));
      bit_code.AddGrad(pre_out_grad, bias_grad);
    }
    bit_code.MulGradWeight(pre_out_grad, w_grad, *in);
    bit_code.MulGradError(pre_out_grad, *w, in_grad);
  }
};

}  // namespace operators
}  // namespace paddle
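As a reading aid, here is a sketch of the quantity the two kernels appear to compute, written in my own notation (the commit does not state it) and under the assumption that MatrixBitCodeFunctor::Sum accumulates PreOut over the path positions whose code bit is 1. For a sample x with label y, let n_1, ..., n_L be the internal nodes on the path from the root to the leaf of y (L <= code_length = FindLastSet(num_classes - 1)), and let t_j in {0, 1} be the corresponding code bits. The forward kernel then produces

\[
z_j = \operatorname{clip}\bigl(w_{n_j}^{\top} x + b_{n_j},\, -40,\, 40\bigr),
\qquad
\mathrm{Out} = \sum_{j=1}^{L} \Bigl[ \log\bigl(1 + e^{z_j}\bigr) - t_j\, z_j \Bigr],
\]

i.e. the binary cross-entropy accumulated along the path, written with the softrelu \(\log(1 + e^{z})\) as the forward comment says: bit_code.Sum contributes the \(-t_j z_j\) terms and RowwiseSum the softrelu terms. The backward pass is consistent with this, since

\[
\frac{\partial\, \mathrm{Out}}{\partial z_j}
= \sigma(z_j) - t_j
= \bigl(1 - e^{-\operatorname{softrelu}(z_j)}\bigr) - t_j,
\]

which is why the gradient kernel computes 1 - 1 / exp(PreOut) (PreOut holds the softrelu values because the forward pass overwrites it) and then calls bit_code.Sub to subtract the code bits, before scaling by the incoming Out gradient.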

paddle/fluid/operators/math/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -51,6 +51,7 @@ math_library(sequence_padding)
 math_library(sequence_pooling DEPS math_function)
 math_library(sequence_scale)
 math_library(softmax DEPS math_function)
+math_library(matrix_bit_code)
 math_library(unpooling)
 math_library(vol2col)

paddle/fluid/operators/math/math_function_impl.h

Lines changed: 1 addition & 1 deletion
@@ -155,7 +155,7 @@ class RowwiseSum<platform::CPUDeviceContext, T> {
     PADDLE_ENFORCE_EQ(in_dims.size(), 2U);
     auto height = in_dims[0];
     auto size = in_dims[1];
-    PADDLE_ENFORCE_EQ(out->numel(), size);
+    PADDLE_ENFORCE_EQ(out->numel(), height);

     T* out_buf = out->mutable_data<T>(out->place());
     const T* in_buf = input.data<T>();
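The one-line change above fixes the shape check in RowwiseSum: summing a [height, size] matrix along each row yields one value per row, so the output must hold height elements, not size. In this commit the functor reduces PreOut ([batch_size, code_length]) to the per-sample loss of shape [batch_size, 1], which is presumably what exposed the wrong check. A tiny standalone illustration of the invariant (plain C++, not Paddle code):

// Summing an H x W matrix along each row yields H values, one per row,
// which is the invariant the corrected PADDLE_ENFORCE_EQ states.
#include <cassert>
#include <vector>

std::vector<double> RowwiseSum(const std::vector<double>& in, int height,
                               int width) {
  std::vector<double> out(height, 0.0);  // one accumulator per row
  for (int r = 0; r < height; ++r) {
    for (int c = 0; c < width; ++c) {
      out[r] += in[r * width + c];
    }
  }
  return out;
}

int main() {
  // A 2 x 3 matrix: rows {1, 2, 3} and {4, 5, 6}.
  std::vector<double> m = {1, 2, 3, 4, 5, 6};
  std::vector<double> s = RowwiseSum(m, /*height=*/2, /*width=*/3);
  assert(s.size() == 2);                // out->numel() == height, not width
  assert(s[0] == 6.0 && s[1] == 15.0);  // per-row sums
  return 0;
}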
