Commit c8801e1 (parent: f37bd03)

grad diff problem to be fixed, and API spec change to be done

File tree: 8 files changed, +324 −60 lines

paddle/fluid/framework/selected_rows.h (2 additions, 1 deletion)
@@ -133,7 +133,8 @@ class SelectedRows {
   // SelectedRows are simply concated when adding together. Until a
   // SelectedRows add a Tensor, will the duplicate rows be handled.
   Vector<int64_t> rows_;
-  std::unordered_map<int64_t, int64_t> id_to_index_;
+  std::unordered_map<int64_t, int64_t>
+      id_to_index_;  // should not be used when ids contain duplicate members
   std::unique_ptr<Tensor> value_{nullptr};
   int64_t height_;
   std::unique_ptr<RWLock> rwlock_{nullptr};
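
The new comment is worth spelling out with a minimal standalone sketch (not Paddle code) of why duplicate ids defeat an id-to-index map: later insertions silently overwrite earlier ones, so one of the duplicate rows becomes unreachable through the map.

#include <cstdint>
#include <iostream>
#include <unordered_map>
#include <vector>

int main() {
  // rows_ with a duplicate id: id 7 appears at indices 1 and 3.
  std::vector<int64_t> rows = {3, 7, 5, 7};
  std::unordered_map<int64_t, int64_t> id_to_index;
  for (int64_t i = 0; i < static_cast<int64_t>(rows.size()); ++i) {
    id_to_index[rows[i]] = i;  // the later index silently overwrites the earlier one
  }
  std::cout << id_to_index[7] << "\n";  // prints 3; the row at index 1 is lost
  return 0;
}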

paddle/fluid/operators/hierarchical_sigmoid_op.cc (10 additions, 1 deletion)
@@ -91,10 +91,19 @@ class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("W",
              "(Tensor, required), The parameters of hierarchical "
              "sigmoid operator, each of them is a 2-D tensor, the shape is"
-             "[num_classes - 1, D].");
+             "[K, D], where K is the number of non-leaf nodes in the path tree.");
     AddInput("Label",
              "(Tensor, required), The labels of training data. It's a"
              "tensor with shape [N, 1].");
+    AddInput("PTable",
+             "(Tensor, optional), The path table from the root to the current "
+             "word; it should have shape [N, L], where L is the path length.")
+        .AsDispensable();
+    AddInput("PCode",
+             "(Tensor, optional), The code at each node on the path from the "
+             "root to the current word; it should have shape [N, L], where L "
+             "is the path length.")
+        .AsDispensable();
     AddInput("Bias",
              "(Tensor, optional), The bias is a tensor with shape"
              "[1, num_classes - 1].");

paddle/fluid/operators/hierarchical_sigmoid_op.h (45 additions, 10 deletions)
@@ -16,6 +16,7 @@ limitations under the License. */
 #include <iostream>
 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/selected_rows.h"
 #include "paddle/fluid/operators/clip_op.h"
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/math/matrix_bit_code.h"
@@ -34,12 +35,21 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto* in = ctx.Input<framework::Tensor>("X");
     auto* w = ctx.Input<framework::Tensor>("W");
+    auto* path = ctx.Input<framework::Tensor>("PTable");
+    auto* code = ctx.Input<framework::Tensor>("PCode");
     auto* label = ctx.Input<framework::Tensor>("Label");
     auto* bias = ctx.Input<framework::Tensor>("Bias");
     auto* out = ctx.Output<framework::Tensor>("Out");
     auto* pre_out = ctx.Output<framework::Tensor>("PreOut");
     size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
-    int64_t code_length = math::FindLastSet(num_classes - 1);
+    bool is_custom = false;
+    if (path) {
+      is_custom = true;
+    } else {
+      is_custom = false;
+    }
+    int64_t code_length =
+        path ? path->dims()[1] : math::FindLastSet(num_classes - 1);
     int64_t batch_size = in->dims()[0];
     framework::Tensor sum;
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
@@ -52,23 +62,31 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
     zero(dev_ctx, pre_out, static_cast<T>(0.0));
     auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
     math::RowwiseSum<DeviceContext, T> row_sum;
-    math::MatrixBitCodeFunctor<T> bit_code(num_classes, label->data<int64_t>());
+
+    std::unique_ptr<math::MatrixBitCodeFunctor<T>> bit_code;
+    if (!is_custom) {
+      bit_code.reset(new math::MatrixBitCodeFunctor<T>(num_classes,
+                                                       label->data<int64_t>()));
+    } else {
+      bit_code.reset(new math::MatrixBitCodeFunctor<T>(path, code,
+                                                       label->data<int64_t>()));
+    }

     std::vector<int64_t> sum_dims({batch_size, 1UL});
     sum.mutable_data<T>(framework::make_ddim(sum_dims), ctx.GetPlace());
     auto sum_mat = EigenMatrix<T>::From(sum);
     out->mutable_data<T>(ctx.GetPlace());
     auto out_mat = framework::EigenVector<T>::Flatten(*out);
     if (bias) {
-      bit_code.Add(pre_out, *bias);
+      bit_code->Add(pre_out, *bias);
     }
-    bit_code.Mul(pre_out, *w, *in);
+    bit_code->Mul(pre_out, *w, *in);
     // clip to [-40, 40]
     Transform<DeviceContext> trans;
     trans(ctx.template device_context<DeviceContext>(), pre_out_data,
           pre_out_data + pre_out->numel(), pre_out_data,
           ClipFunctor<T>(static_cast<T>(-40.0), static_cast<T>(40.0)));
-    bit_code.Sum(*pre_out, out, static_cast<T>(-1));
+    bit_code->Sum(*pre_out, out, static_cast<T>(-1));
     // use softrelu to calculate cross entropy
     pre_out_mat.device(place) = (static_cast<T>(1.0) + pre_out_mat.exp()).log();
     row_sum(dev_ctx, *pre_out, &sum);
@@ -86,6 +104,8 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto* in = ctx.Input<framework::Tensor>("X");
     auto* w = ctx.Input<framework::Tensor>("W");
+    auto* path = ctx.Input<framework::Tensor>("PTable");
+    auto* code = ctx.Input<framework::Tensor>("PCode");
     auto* in_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
     auto* w_grad = ctx.Output<framework::Tensor>(framework::GradVarName("W"));
     auto* bias_grad =
@@ -105,7 +125,22 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
     zero(dev_ctx, w_grad, static_cast<T>(0.0));

     size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
-    math::MatrixBitCodeFunctor<T> bit_code(num_classes, label->data<int64_t>());
+
+    bool is_custom = false;
+    if (path) {
+      is_custom = true;
+    } else {
+      is_custom = false;
+    }
+
+    std::unique_ptr<math::MatrixBitCodeFunctor<T>> bit_code;
+    if (!is_custom) {
+      bit_code.reset(new math::MatrixBitCodeFunctor<T>(num_classes,
+                                                       label->data<int64_t>()));
+    } else {
+      bit_code.reset(new math::MatrixBitCodeFunctor<T>(path, code,
+                                                       label->data<int64_t>()));
+    }

     auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
     auto pre_out_mat = EigenMatrix<T>::From(*pre_out);
@@ -116,18 +151,18 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
     // softrelu derivative
     pre_out_grad_mat.device(place) =
         static_cast<T>(1.0) - static_cast<T>(1.0) / pre_out_mat.exp();
-    bit_code.Sub(&pre_out_grad);  // the gradient of clip(w * x + b)
+    bit_code->Sub(&pre_out_grad);  // the gradient of clip(w * x + b)
     pre_out_grad_mat.device(place) =
         pre_out_grad_mat * out_grad_mat.broadcast(bcast);
     // TODO(guosheng): multiply pre_out_grad with subgradient of clipping to
     // be consistent with the clipping in forward.
     if (bias_grad) {
       bias_grad->mutable_data<T>(ctx.GetPlace());
       zero(dev_ctx, bias_grad, static_cast<T>(0.0));
-      bit_code.AddGrad(pre_out_grad, bias_grad);
+      bit_code->AddGrad(pre_out_grad, bias_grad);
     }
-    bit_code.MulGradWeight(pre_out_grad, w_grad, *in);
-    bit_code.MulGradError(pre_out_grad, *w, in_grad);
+    bit_code->MulGradWeight(pre_out_grad, w_grad, *in);
+    bit_code->MulGradError(pre_out_grad, *w, in_grad);
   }
 };
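
Both kernels now pick a MatrixBitCodeFunctor construction at runtime based on whether the optional PTable input is fed. A self-contained mock of that dispatch follows; MockTensor, MockFunctor, and MakeBitCode are illustrative stand-ins, not Paddle APIs, and the mock also shows the five-line is_custom flag collapsed to a single expression.

#include <cstddef>
#include <iostream>
#include <memory>

struct MockTensor {};

struct MockFunctor {
  explicit MockFunctor(size_t num_classes) {
    std::cout << "default tree over " << num_classes << " classes\n";
  }
  MockFunctor(const MockTensor* /*path*/, const MockTensor* /*code*/) {
    std::cout << "custom path table\n";
  }
};

std::unique_ptr<MockFunctor> MakeBitCode(const MockTensor* path,
                                         const MockTensor* code,
                                         size_t num_classes) {
  // Mirrors the kernels' choice: `bool is_custom = (path != nullptr);`
  if (path) {
    return std::unique_ptr<MockFunctor>(new MockFunctor(path, code));
  }
  return std::unique_ptr<MockFunctor>(new MockFunctor(num_classes));
}

int main() {
  MockTensor table, codes;
  MakeBitCode(nullptr, nullptr, 10);  // prints: default tree over 10 classes
  MakeBitCode(&table, &codes, 10);    // prints: custom path table
  return 0;
}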

paddle/fluid/operators/math/matrix_bit_code.cc (21 additions, 28 deletions)
@@ -21,14 +21,13 @@ namespace math {
 template <typename T>
 void MatrixBitCodeFunctor<T>::Add(framework::Tensor* tmat,
                                   const framework::Tensor& vec) {
-  SimpleCodeTable code_table(num_classes_);
   size_t batch_size = tmat->dims()[0];
   size_t width = tmat->dims()[1];
   for (size_t i = 0; i < batch_size; ++i) {
-    auto code = code_table(static_cast<size_t>(ids_[i]));
-    int code_length = code.get_length();
+    auto code = code_table->get_code(i);
+    int code_length = code->get_length();
     for (int j = 0; j < code_length; ++j) {
-      size_t index = code.calc_index(j);
+      size_t index = code->calc_index(j);
       tmat->data<T>()[i * width + j] += vec.data<T>()[index];
     }
   }
@@ -37,14 +36,13 @@ void MatrixBitCodeFunctor<T>::Add(framework::Tensor* tmat,
 template <typename T>
 void MatrixBitCodeFunctor<T>::AddGrad(const framework::Tensor& tmat,
                                       framework::Tensor* vec) {
-  SimpleCodeTable code_table(num_classes_);
   size_t batch_size = tmat.dims()[0];
   size_t width = tmat.dims()[1];
   for (size_t i = 0; i < batch_size; ++i) {
-    auto code = code_table(static_cast<size_t>(ids_[i]));
-    int code_length = code.get_length();
+    auto code = code_table->get_code(i);
+    int code_length = code->get_length();
     for (int j = 0; j < code_length; ++j) {
-      size_t index = code.calc_index(j);
+      size_t index = code->calc_index(j);
       vec->data<T>()[index] += tmat.data<T>()[i * width + j];
     }
   }
@@ -53,15 +51,14 @@ void MatrixBitCodeFunctor<T>::AddGrad(const framework::Tensor& tmat,
 template <typename T>
 void MatrixBitCodeFunctor<T>::Sum(const framework::Tensor& tmat,
                                   framework::Tensor* sum, T scale_sum) {
-  SimpleCodeTable code_table(num_classes_);
   size_t num_samples = tmat.dims()[0];
   size_t o_width = tmat.dims()[1];
   for (size_t i = 0; i < num_samples; ++i) {
     T sm = static_cast<T>(0.0);
-    auto code = code_table(static_cast<size_t>(ids_[i]));
-    int code_length = code.get_length();
+    auto code = code_table->get_code(i);
+    int code_length = code->get_length();
     for (int j = 0; j < code_length; ++j) {
-      if (code.calc_bit(j)) {
+      if (code->calc_bit(j)) {
         // calc_bit starts from right most bit, while data in tmat[i] is in the
         // reverse order.
         sm += tmat.data<T>()[i * o_width + j];
@@ -75,7 +72,6 @@ template <typename T>
 void MatrixBitCodeFunctor<T>::Mul(framework::Tensor* tmat,
                                   const framework::Tensor& weight,
                                   const framework::Tensor& input) {
-  SimpleCodeTable code_table(num_classes_);
   size_t num_samples = tmat->dims()[0];
   size_t tmat_width = tmat->dims()[1];
   size_t input_width = input.dims()[1];
@@ -84,10 +80,10 @@ void MatrixBitCodeFunctor<T>::Mul(framework::Tensor* tmat,
   auto weight_value = weight.data<T>();
   auto input_value = input.data<T>();
   for (size_t i = 0; i < num_samples; ++i) {
-    auto code = code_table(static_cast<size_t>(ids_[i]));
-    int code_length = code.get_length();
+    auto code = code_table->get_code(i);
+    int code_length = code->get_length();
     for (int j = 0; j < code_length; ++j) {
-      size_t index = code.calc_index(j);
+      size_t index = code->calc_index(j);
       T sum = static_cast<T>(0.0);
       for (size_t k = 0; k < input_width; ++k) {
         sum += weight_value[weight_width * index + k] *
@@ -102,7 +98,6 @@ template <typename T>
 void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat,
                                             framework::Tensor* weight,
                                             const framework::Tensor& input) {
-  SimpleCodeTable code_table(num_classes_);
   size_t num_samples = tmat.dims()[0];
   size_t input_width = input.dims()[1];
   size_t tmat_width = tmat.dims()[1];
@@ -111,10 +106,10 @@ void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat,
   auto weight_value = weight->data<T>();
   auto input_value = input.data<T>();
   for (size_t i = 0; i < num_samples; ++i) {
-    auto code = code_table(static_cast<size_t>(ids_[i]));
-    int code_length = code.get_length();
+    auto code = code_table->get_code(i);
+    int code_length = code->get_length();
     for (int j = 0; j < code_length; ++j) {
-      size_t index = code.calc_index(j);
+      size_t index = code->calc_index(j);

       for (size_t k = 0; k < input_width; ++k) {
         weight_value[weight_width * index + k] +=
@@ -128,7 +123,6 @@ template <typename T>
 void MatrixBitCodeFunctor<T>::MulGradError(const framework::Tensor& tmat,
                                            const framework::Tensor& weight,
                                            framework::Tensor* input) {
-  SimpleCodeTable code_table(num_classes_);
   size_t num_samples = tmat.dims()[0];
   size_t tmat_width = tmat.dims()[1];
   size_t input_width = input->dims()[1];
@@ -138,10 +132,10 @@ void MatrixBitCodeFunctor<T>::MulGradError(const framework::Tensor& tmat,
   auto input_value = input->data<T>();

   for (size_t i = 0; i < num_samples; ++i) {
-    auto code = code_table(static_cast<size_t>(ids_[i]));
-    int code_length = code.get_length();
+    auto code = code_table->get_code(i);
+    int code_length = code->get_length();
     for (int j = 0; j < code_length; ++j) {
-      size_t index = code.calc_index(j);
+      size_t index = code->calc_index(j);

       for (size_t k = 0; k < input_width; ++k) {
         input_value[input_width * i + k] +=
@@ -154,14 +148,13 @@ void MatrixBitCodeFunctor<T>::MulGradError(const framework::Tensor& tmat,

 template <typename T>
 void MatrixBitCodeFunctor<T>::Sub(framework::Tensor* tmat) {
-  SimpleCodeTable code_table(num_classes_);
   size_t num_samples = tmat->dims()[0];
   size_t o_width = tmat->dims()[1];
   for (size_t i = 0; i < num_samples; ++i) {
-    auto code = code_table(static_cast<size_t>(ids_[i]));
-    int code_length = code.get_length();
+    auto code = code_table->get_code(i);
+    int code_length = code->get_length();
     for (int j = 0; j < code_length; ++j) {
-      if (code.calc_bit(j)) {
+      if (code->calc_bit(j)) {
         tmat->data<T>()[i * o_width + j] -= 1;
       }
     }
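
Every method above now asks a stored code_table for a per-sample code object instead of constructing a SimpleCodeTable from num_classes_ on each call; that indirection is what lets a custom PTable/PCode-backed table plug in behind the same call sites. A hedged sketch of the interface those calls imply, with an illustrative SimpleCode reproducing the old complete-binary-tree scheme; this mirrors the call sites but is not the verbatim Paddle header.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <memory>

// Interface implied by code_table->get_code(i) followed by get_length(),
// calc_index(j), and calc_bit(j) on the result.
class Code {
 public:
  virtual ~Code() = default;
  virtual int get_length() const = 0;            // nodes on this sample's path
  virtual size_t calc_index(int bit) const = 0;  // non-leaf node id at step bit
  virtual bool calc_bit(int bit) const = 0;      // branch taken at step bit
};

// Default scheme: label id + num_classes encodes a path in a complete binary
// tree, as the removed SimpleCodeTable(num_classes_) computed per call.
class SimpleCode : public Code {
 public:
  SimpleCode(size_t id, size_t num_classes) : c_(id + num_classes) {}
  int get_length() const override {
    int len = 0;
    for (size_t c = c_; c > 1; c >>= 1) ++len;  // floor(log2(c_)), FindLastSet-style
    return len;
  }
  size_t calc_index(int bit) const override { return (c_ >> (bit + 1)) - 1; }
  bool calc_bit(int bit) const override { return (c_ >> bit) & 1; }

 private:
  size_t c_;
};

// Table held by the functor; get_code(i) resolves sample i's label to a code.
// A custom table would instead return codes backed by PTable/PCode rows.
class SimpleCodeTable {
 public:
  SimpleCodeTable(size_t num_classes, const int64_t* ids)
      : num_classes_(num_classes), ids_(ids) {}
  std::unique_ptr<Code> get_code(size_t i) const {
    return std::unique_ptr<Code>(
        new SimpleCode(static_cast<size_t>(ids_[i]), num_classes_));
  }

 private:
  size_t num_classes_;
  const int64_t* ids_;
};

int main() {
  const int64_t ids[] = {3};
  SimpleCodeTable table(/*num_classes=*/8, ids);
  auto code = table.get_code(0);
  std::cout << code->get_length() << "\n";  // 3: label 3 sits at depth 3 of an 8-leaf tree
  return 0;
}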
