Skip to content

Commit 8a0c7e2

Browse files
authored
Merge pull request #10280 from reyoung/feature/add_stable_test_of_cross_entropy
Clean cross entropy
2 parents 4a497b8 + 53c2768 commit 8a0c7e2

File tree

3 files changed

+89
-137
lines changed

3 files changed

+89
-137
lines changed

paddle/fluid/operators/cross_entropy_op.cc

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -164,11 +164,13 @@ or not. But the output only shares the LoD information with input X.
164164
} // namespace paddle
165165

166166
namespace ops = paddle::operators;
167+
using CPUCtx = paddle::platform::CPUDeviceContext;
168+
167169
REGISTER_OPERATOR(cross_entropy, ops::CrossEntropyOp, ops::CrossEntropyOpMaker,
168170
paddle::framework::DefaultGradOpDescMaker<true>);
169171
REGISTER_OPERATOR(cross_entropy_grad, ops::CrossEntropyGradientOp);
170-
REGISTER_OP_CPU_KERNEL(cross_entropy, ops::CrossEntropyOpKernel<float>,
171-
ops::CrossEntropyOpKernel<double>);
172+
REGISTER_OP_CPU_KERNEL(cross_entropy, ops::CrossEntropyOpKernel<CPUCtx, float>,
173+
ops::CrossEntropyOpKernel<CPUCtx, double>);
172174
REGISTER_OP_CPU_KERNEL(cross_entropy_grad,
173-
ops::CrossEntropyGradientOpKernel<float>,
174-
ops::CrossEntropyGradientOpKernel<double>);
175+
ops::CrossEntropyGradientOpKernel<CPUCtx, float>,
176+
ops::CrossEntropyGradientOpKernel<CPUCtx, double>);

paddle/fluid/operators/cross_entropy_op.cu

Lines changed: 6 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -14,98 +14,11 @@ limitations under the License. */
1414

1515
#include "paddle/fluid/operators/cross_entropy_op.h"
1616

17-
namespace paddle {
18-
namespace operators {
19-
20-
namespace {
21-
22-
template <typename T>
23-
__global__ void CrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
24-
const int64_t* label, const int N,
25-
const int D) {
26-
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
27-
i += blockDim.x * gridDim.x) {
28-
int idx = i * D + label[i];
29-
dX[idx] = -dY[i] / X[idx];
30-
}
31-
}
32-
33-
template <typename T>
34-
__global__ void SoftCrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
35-
const T* label, const int N,
36-
const int D) {
37-
int ids = blockIdx.x * blockDim.x + threadIdx.x;
38-
if (ids < N * D) {
39-
int row_ids = ids / D;
40-
dX[ids] = -label[ids] * dY[row_ids] / X[ids];
41-
}
42-
}
43-
} // namespace
44-
45-
template <typename T>
46-
class CrossEntropyOpCUDAKernel : public framework::OpKernel<T> {
47-
public:
48-
void Compute(const framework::ExecutionContext& ctx) const override {
49-
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
50-
"This kernel only runs on GPU device.");
51-
const Tensor* x = ctx.Input<Tensor>("X");
52-
const Tensor* label = ctx.Input<Tensor>("Label");
53-
Tensor* y = ctx.Output<Tensor>("Y");
54-
y->mutable_data<T>(ctx.GetPlace());
55-
56-
math::CrossEntropyFunctor<platform::CUDADeviceContext, T>()(
57-
ctx.template device_context<platform::CUDADeviceContext>(), y, x, label,
58-
ctx.Attr<bool>("soft_label"));
59-
}
60-
};
61-
62-
template <typename T>
63-
class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel<T> {
64-
public:
65-
void Compute(const framework::ExecutionContext& ctx) const override {
66-
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
67-
"This kernel only runs on GPU device.");
68-
69-
const Tensor* x = ctx.Input<Tensor>("X");
70-
const Tensor* label = ctx.Input<Tensor>("Label");
71-
Tensor* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
72-
dx->mutable_data<T>(ctx.GetPlace());
73-
74-
const T* dy_data =
75-
ctx.Input<Tensor>(framework::GradVarName("Y"))->data<T>();
76-
T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
77-
const T* x_data = x->data<T>();
78-
79-
int64_t batch_size = x->dims()[0];
80-
int64_t class_num = x->dims()[1];
81-
82-
int block = 512;
83-
int grid = (batch_size * class_num + block - 1) / block;
84-
85-
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
86-
auto stream = dev_ctx.stream();
87-
88-
if (ctx.Attr<bool>("soft_label")) {
89-
auto* label_data = label->data<T>();
90-
SoftCrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
91-
dx_data, dy_data, x_data, label_data, batch_size, class_num);
92-
} else {
93-
math::SetConstant<platform::CUDADeviceContext, T> functor;
94-
functor(dev_ctx, dx, 0);
95-
auto* label_data = label->data<int64_t>();
96-
grid = (batch_size + block - 1) / block;
97-
CrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
98-
dx_data, dy_data, x_data, label_data, batch_size, class_num);
99-
}
100-
}
101-
};
102-
103-
} // namespace operators
104-
} // namespace paddle
105-
10617
namespace ops = paddle::operators;
107-
REGISTER_OP_CUDA_KERNEL(cross_entropy, ops::CrossEntropyOpCUDAKernel<float>,
108-
ops::CrossEntropyOpCUDAKernel<double>);
18+
using CUDACtx = paddle::platform::CUDADeviceContext;
19+
REGISTER_OP_CUDA_KERNEL(cross_entropy,
20+
ops::CrossEntropyOpKernel<CUDACtx, float>,
21+
ops::CrossEntropyOpKernel<CUDACtx, double>);
10922
REGISTER_OP_CUDA_KERNEL(cross_entropy_grad,
110-
ops::CrossEntropyGradientOpCUDAKernel<float>,
111-
ops::CrossEntropyGradientOpCUDAKernel<double>);
23+
ops::CrossEntropyGradientOpKernel<CUDACtx, float>,
24+
ops::CrossEntropyGradientOpKernel<CUDACtx, double>);

paddle/fluid/operators/cross_entropy_op.h

Lines changed: 77 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -17,69 +17,106 @@ limitations under the License. */
1717
#include "paddle/fluid/framework/op_registry.h"
1818
#include "paddle/fluid/operators/math/cross_entropy.h"
1919
#include "paddle/fluid/operators/math/math_function.h"
20+
#include "paddle/fluid/platform/for_range.h"
2021

2122
namespace paddle {
2223
namespace operators {
2324

2425
using Tensor = framework::Tensor;
25-
template <typename T, int MajorType = Eigen::RowMajor,
26-
typename IndexType = Eigen::DenseIndex>
27-
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
2826

29-
template <typename T>
27+
template <typename DeviceContext, typename T>
3028
class CrossEntropyOpKernel : public framework::OpKernel<T> {
3129
public:
3230
void Compute(const framework::ExecutionContext& ctx) const override {
33-
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
34-
"This kernel only runs on CPU.");
35-
const Tensor* x = ctx.Input<Tensor>("X");
36-
const Tensor* labels = ctx.Input<Tensor>("Label");
37-
Tensor* y = ctx.Output<Tensor>("Y");
31+
auto* x = ctx.Input<Tensor>("X");
32+
auto* labels = ctx.Input<Tensor>("Label");
33+
auto* y = ctx.Output<Tensor>("Y");
3834
y->mutable_data<T>(ctx.GetPlace());
3935

40-
math::CrossEntropyFunctor<platform::CPUDeviceContext, T>()(
41-
ctx.template device_context<platform::CPUDeviceContext>(), y, x, labels,
36+
math::CrossEntropyFunctor<DeviceContext, T>()(
37+
ctx.template device_context<DeviceContext>(), y, x, labels,
4238
ctx.Attr<bool>("soft_label"));
4339
}
4440
};
4541

4642
template <typename T>
43+
class XeSoftlabelGradFunctor {
44+
public:
45+
XeSoftlabelGradFunctor(T* dx,
46+
const T* dy, // NOLINT
47+
const T* x, // NOLINT
48+
const T* label, // NOLINT
49+
size_t num_classes)
50+
: dx_(dx), dy_(dy), x_(x), label_(label), num_classes_(num_classes) {}
51+
52+
HOSTDEVICE void operator()(size_t i) {
53+
auto row_ids = i / num_classes_;
54+
dx_[i] = -label_[i] * dy_[row_ids] / x_[i];
55+
}
56+
57+
private:
58+
T* dx_;
59+
const T* dy_;
60+
const T* x_;
61+
const T* label_;
62+
size_t num_classes_;
63+
};
64+
65+
template <typename T>
66+
class XeGradFunctor {
67+
public:
68+
XeGradFunctor(T* dx,
69+
const T* dy, // NOLINT
70+
const T* x, // NOLINT
71+
const int64_t* label, // NOLINT
72+
size_t num_classes)
73+
: dx_(dx), dy_(dy), x_(x), label_(label), num_classes_(num_classes) {}
74+
75+
HOSTDEVICE void operator()(size_t sample_id) {
76+
auto x_is_true_offset = sample_id * num_classes_ + label_[sample_id];
77+
for (size_t x_offset = sample_id * num_classes_;
78+
x_offset < (sample_id + 1) * num_classes_; ++x_offset) {
79+
dx_[x_offset] = x_offset != x_is_true_offset
80+
? static_cast<T>(0)
81+
: -dy_[sample_id] / x_[x_offset];
82+
}
83+
}
84+
85+
private:
86+
T* dx_;
87+
const T* dy_;
88+
const T* x_;
89+
const int64_t* label_;
90+
size_t num_classes_;
91+
};
92+
93+
template <typename DeviceContext, typename T>
4794
class CrossEntropyGradientOpKernel : public framework::OpKernel<T> {
4895
public:
4996
void Compute(const framework::ExecutionContext& ctx) const override {
50-
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
51-
"This kernel only runs on CPU.");
52-
const Tensor* x = ctx.Input<Tensor>("X");
53-
const Tensor* dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
54-
const Tensor* label = ctx.Input<Tensor>("Label");
55-
Tensor* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
56-
T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
97+
auto* x = ctx.Input<Tensor>("X");
98+
auto* dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
99+
auto* label = ctx.Input<Tensor>("Label");
100+
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
101+
auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
57102

58103
int64_t class_num = x->dims()[1];
59104
if (ctx.Attr<bool>("soft_label")) {
60-
auto x_mat = EigenMatrix<T>::From(*x);
61-
auto dy_mat = EigenMatrix<T>::From(*dy);
62-
auto lbl_mat = EigenMatrix<T>::From(*label);
63-
auto dx_mat = EigenMatrix<T>::From(*dx);
64-
65-
dx_mat.device(*ctx.template device_context<platform::CPUDeviceContext>()
66-
.eigen_device()) =
67-
-(lbl_mat *
68-
dy_mat.broadcast(Eigen::DSizes<int64_t, 2>(1, class_num)) / x_mat);
105+
XeSoftlabelGradFunctor<T> functor(dx_data, dy->data<T>(), x->data<T>(),
106+
label->data<T>(),
107+
static_cast<size_t>(class_num));
108+
platform::ForRange<DeviceContext> for_range(
109+
ctx.template device_context<DeviceContext>(),
110+
static_cast<size_t>(dx->numel()));
111+
for_range(functor);
69112
} else {
70-
int64_t batch_size = x->dims()[0];
71-
const T* dy_data = dy->data<T>();
72-
const T* x_data = x->data<T>();
73-
const int64_t* label_data = label->data<int64_t>();
74-
75-
math::SetConstant<platform::CPUDeviceContext, T> functor;
76-
functor(ctx.template device_context<platform::CPUDeviceContext>(), dx, 0);
77-
78-
for (int64_t i = 0; i < batch_size; ++i) {
79-
PADDLE_ASSERT(label_data[i] >= 0 || label_data[i] < class_num);
80-
int64_t index = i * class_num + label_data[i];
81-
dx_data[index] = math::TolerableValue<T>()(-dy_data[i] / x_data[index]);
82-
}
113+
XeGradFunctor<T> functor(dx_data, dy->data<T>(), x->data<T>(),
114+
label->data<int64_t>(),
115+
static_cast<size_t>(class_num));
116+
platform::ForRange<DeviceContext> for_range(
117+
ctx.template device_context<DeviceContext>(),
118+
static_cast<size_t>(dy->numel()));
119+
for_range(functor);
83120
}
84121
}
85122
};

0 commit comments

Comments
 (0)