@@ -17,69 +17,106 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/math/cross_entropy.h"
 #include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/platform/for_range.h"
 
 namespace paddle {
 namespace operators {
 
 using Tensor = framework::Tensor;
-template <typename T, int MajorType = Eigen::RowMajor,
-          typename IndexType = Eigen::DenseIndex>
-using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
 
-template <typename T>
+template <typename DeviceContext, typename T>
 class CrossEntropyOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
-                   "This kernel only runs on CPU.");
-    const Tensor* x = ctx.Input<Tensor>("X");
-    const Tensor* labels = ctx.Input<Tensor>("Label");
-    Tensor* y = ctx.Output<Tensor>("Y");
+    auto* x = ctx.Input<Tensor>("X");
+    auto* labels = ctx.Input<Tensor>("Label");
+    auto* y = ctx.Output<Tensor>("Y");
     y->mutable_data<T>(ctx.GetPlace());
 
-    math::CrossEntropyFunctor<platform::CPUDeviceContext, T>()(
-        ctx.template device_context<platform::CPUDeviceContext>(), y, x, labels,
+    math::CrossEntropyFunctor<DeviceContext, T>()(
+        ctx.template device_context<DeviceContext>(), y, x, labels,
         ctx.Attr<bool>("soft_label"));
   }
 };
 
 template <typename T>
+class XeSoftlabelGradFunctor {
+ public:
+  XeSoftlabelGradFunctor(T* dx,
+                         const T* dy,     // NOLINT
+                         const T* x,      // NOLINT
+                         const T* label,  // NOLINT
+                         size_t num_classes)
+      : dx_(dx), dy_(dy), x_(x), label_(label), num_classes_(num_classes) {}
+
+  HOSTDEVICE void operator()(size_t i) {
+    auto row_ids = i / num_classes_;
+    dx_[i] = -label_[i] * dy_[row_ids] / x_[i];
+  }
+
+ private:
+  T* dx_;
+  const T* dy_;
+  const T* x_;
+  const T* label_;
+  size_t num_classes_;
+};
+
+template <typename T>
+class XeGradFunctor {
+ public:
+  XeGradFunctor(T* dx,
+                const T* dy,           // NOLINT
+                const T* x,            // NOLINT
+                const int64_t* label,  // NOLINT
+                size_t num_classes)
+      : dx_(dx), dy_(dy), x_(x), label_(label), num_classes_(num_classes) {}
+
+  HOSTDEVICE void operator()(size_t sample_id) {
+    auto x_is_true_offset = sample_id * num_classes_ + label_[sample_id];
+    for (size_t x_offset = sample_id * num_classes_;
+         x_offset < (sample_id + 1) * num_classes_; ++x_offset) {
+      dx_[x_offset] = x_offset != x_is_true_offset
+                          ? static_cast<T>(0)
+                          : -dy_[sample_id] / x_[x_offset];
+    }
+  }
+
+ private:
+  T* dx_;
+  const T* dy_;
+  const T* x_;
+  const int64_t* label_;
+  size_t num_classes_;
+};
+
+template <typename DeviceContext, typename T>
 class CrossEntropyGradientOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
-                   "This kernel only runs on CPU.");
-    const Tensor* x = ctx.Input<Tensor>("X");
-    const Tensor* dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
-    const Tensor* label = ctx.Input<Tensor>("Label");
-    Tensor* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
-    T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
+    auto* x = ctx.Input<Tensor>("X");
+    auto* dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
+    auto* label = ctx.Input<Tensor>("Label");
+    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
 
     int64_t class_num = x->dims()[1];
     if (ctx.Attr<bool>("soft_label")) {
-      auto x_mat = EigenMatrix<T>::From(*x);
-      auto dy_mat = EigenMatrix<T>::From(*dy);
-      auto lbl_mat = EigenMatrix<T>::From(*label);
-      auto dx_mat = EigenMatrix<T>::From(*dx);
-
-      dx_mat.device(*ctx.template device_context<platform::CPUDeviceContext>()
-                         .eigen_device()) =
-          -(lbl_mat *
-            dy_mat.broadcast(Eigen::DSizes<int64_t, 2>(1, class_num)) / x_mat);
+      XeSoftlabelGradFunctor<T> functor(dx_data, dy->data<T>(), x->data<T>(),
+                                        label->data<T>(),
+                                        static_cast<size_t>(class_num));
+      platform::ForRange<DeviceContext> for_range(
+          ctx.template device_context<DeviceContext>(),
+          static_cast<size_t>(dx->numel()));
+      for_range(functor);
     } else {
-      int64_t batch_size = x->dims()[0];
-      const T* dy_data = dy->data<T>();
-      const T* x_data = x->data<T>();
-      const int64_t* label_data = label->data<int64_t>();
-
-      math::SetConstant<platform::CPUDeviceContext, T> functor;
-      functor(ctx.template device_context<platform::CPUDeviceContext>(), dx, 0);
-
-      for (int64_t i = 0; i < batch_size; ++i) {
-        PADDLE_ASSERT(label_data[i] >= 0 || label_data[i] < class_num);
-        int64_t index = i * class_num + label_data[i];
-        dx_data[index] = math::TolerableValue<T>()(-dy_data[i] / x_data[index]);
-      }
+      XeGradFunctor<T> functor(dx_data, dy->data<T>(), x->data<T>(),
+                               label->data<int64_t>(),
+                               static_cast<size_t>(class_num));
+      platform::ForRange<DeviceContext> for_range(
+          ctx.template device_context<DeviceContext>(),
+          static_cast<size_t>(dy->numel()));
+      for_range(functor);
     }
   }
 };
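
Note: the refactor works because platform::ForRange hides the per-index loop behind the DeviceContext type. As a rough sketch (an illustrative approximation, not the actual for_range.h header), the CPU path amounts to a plain loop that invokes the functor's HOSTDEVICE operator() for every index, while the CUDA specialization launches a kernel that runs the same body with one index per thread:

// Illustrative sketch only: approximates the CPU behavior of
// platform::ForRange<DeviceContext>. The real header also provides a CUDA
// specialization that launches a kernel so the same functor body runs once
// per index on the GPU.
#include <cstddef>

template <typename Function>
void CpuForRange(std::size_t limit, Function func) {
  for (std::size_t i = 0; i < limit; ++i) {
    func(i);  // e.g. XeGradFunctor<T>::operator()(size_t)
  }
}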
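
With DeviceContext lifted into a template parameter, the same kernel classes can be registered for both CPU and CUDA. A hedged sketch of what the accompanying registrations in cross_entropy_op.cc and cross_entropy_op.cu would look like, using Paddle's kernel-registration macros (the exact type lists are an assumption):

// Sketch (assumed registrations): CPU side in cross_entropy_op.cc ...
namespace ops = paddle::operators;
using CPUCtx = paddle::platform::CPUDeviceContext;
REGISTER_OP_CPU_KERNEL(cross_entropy,
                       ops::CrossEntropyOpKernel<CPUCtx, float>,
                       ops::CrossEntropyOpKernel<CPUCtx, double>);

// ... and CUDA side in cross_entropy_op.cu.
using CUDACtx = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(cross_entropy,
                        ops::CrossEntropyOpKernel<CUDACtx, float>,
                        ops::CrossEntropyOpKernel<CUDACtx, double>);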