@@ -19,9 +19,11 @@ limitations under the License. */
19
19
#include " paddle/fluid/framework/mixed_vector.h"
20
20
#include " paddle/fluid/framework/op_registry.h"
21
21
#include " paddle/fluid/operators/clip_op.h"
22
+ #include " paddle/fluid/operators/detail/safe_ref.h"
22
23
#include " paddle/fluid/operators/math/math_function.h"
23
24
#include " paddle/fluid/operators/math/matrix_bit_code.h"
24
25
#include " paddle/fluid/platform/transform.h"
26
+
25
27
namespace paddle {
26
28
namespace operators {
27
29
@@ -30,31 +32,26 @@ template <typename T, int MajorType = Eigen::RowMajor,
30
32
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
31
33
using platform::Transform;
32
34
33
- std::vector<int64_t > cal_rows (const framework::LoDTensor& path) {
34
- std::set<int64_t > tmp;
35
- std::vector<int64_t > rows;
36
- for (size_t i = 0 ; i < static_cast <size_t >(path.dims ()[0 ]); i++) {
37
- for (size_t j = 0 ; j < static_cast <size_t >(path.dims ()[1 ]); j++) {
38
- int64_t temp =
39
- path.data <int64_t >()[i * static_cast <size_t >(path.dims ()[1 ]) + j];
40
- if (temp >= 0 ) {
41
- tmp.insert (temp);
42
- }
35
+ static std::vector<int64_t > PathToRows (const framework::LoDTensor& path) {
36
+ std::set<int64_t > rows;
37
+ for (int64_t i = 0 ; i < path.numel (); ++i) {
38
+ int64_t row = path.data <int64_t >()[i];
39
+ if (row < 0 ) {
40
+ continue ;
43
41
}
42
+ rows.emplace (row);
44
43
}
45
- rows.assign (tmp.begin (), tmp.end ());
46
- return rows;
44
+ return std::vector<int64_t >(rows.begin (), rows.end ());
47
45
}
48
-
49
46
template <typename DeviceContext, typename T>
50
47
class HierarchicalSigmoidOpKernel : public framework ::OpKernel<T> {
51
48
public:
52
49
void Compute (const framework::ExecutionContext& ctx) const override {
53
- auto * in = ctx.Input <framework::LoDTensor>(" X" );
54
- auto * w = ctx.Input <framework::LoDTensor>(" W" );
50
+ auto in = detail::Ref ( ctx.Input <framework::LoDTensor>(" X" ) );
51
+ auto w = detail::Ref ( ctx.Input <framework::LoDTensor>(" W" ) );
55
52
auto * path = ctx.Input <framework::LoDTensor>(" PTable" );
56
- auto * code = ctx.Input <framework::LoDTensor>(" PCode " );
57
- auto * label = ctx.Input <framework::LoDTensor>(" Label" );
53
+ auto * code = ctx.Input <framework::LoDTensor>(" PathCode " );
54
+ auto label = detail::Ref ( ctx.Input <framework::LoDTensor>(" Label" ) );
58
55
auto * bias = ctx.Input <framework::LoDTensor>(" Bias" );
59
56
auto * out = ctx.Output <framework::LoDTensor>(" Out" );
60
57
auto * pre_out = ctx.Output <framework::LoDTensor>(" PreOut" );
@@ -65,7 +62,7 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
65
62
}
66
63
int64_t code_length =
67
64
path ? path->dims ()[1 ] : math::FindLastSet (num_classes - 1 );
68
- int64_t batch_size = in-> dims ()[0 ];
65
+ int64_t batch_size = in. dims ()[0 ];
69
66
framework::LoDTensor sum;
70
67
auto & dev_ctx = ctx.template device_context <DeviceContext>();
71
68
auto * pre_out_data = pre_out->mutable_data <T>(
@@ -81,10 +78,10 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
81
78
std::unique_ptr<math::MatrixBitCodeFunctor<T>> bit_code;
82
79
if (!is_custom) {
83
80
bit_code.reset (new math::MatrixBitCodeFunctor<T>(num_classes,
84
- label-> data <int64_t >()));
81
+ label. data <int64_t >()));
85
82
} else {
86
- bit_code.reset (new math::MatrixBitCodeFunctor<T>(path, code,
87
- label-> data <int64_t >()));
83
+ bit_code.reset (new math::MatrixBitCodeFunctor<T>(* path, * code,
84
+ label. data <int64_t >()));
88
85
}
89
86
90
87
std::vector<int64_t > sum_dims ({batch_size, 1UL });
@@ -95,7 +92,7 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
95
92
if (bias) {
96
93
bit_code->Add (*bias, pre_out);
97
94
}
98
- bit_code->Mul (pre_out, * w, * in);
95
+ bit_code->Mul (pre_out, w, in);
99
96
// clip to [-40, 40]
100
97
Transform<DeviceContext> trans;
101
98
trans (ctx.template device_context <DeviceContext>(), pre_out_data,
@@ -117,23 +114,23 @@ template <typename DeviceContext, typename T>
117
114
class HierarchicalSigmoidGradOpKernel : public framework ::OpKernel<T> {
118
115
public:
119
116
void Compute (const framework::ExecutionContext& ctx) const override {
120
- auto * in = ctx.Input <framework::LoDTensor>(" X" );
121
- auto * w = ctx.Input <framework::LoDTensor>(" W" );
117
+ auto in = detail::Ref ( ctx.Input <framework::LoDTensor>(" X" ) );
118
+ auto w = detail::Ref ( ctx.Input <framework::LoDTensor>(" W" ) );
122
119
auto * path = ctx.Input <framework::LoDTensor>(" PTable" );
123
- auto * code = ctx.Input <framework::LoDTensor>(" PCode " );
120
+ auto * code = ctx.Input <framework::LoDTensor>(" PathCode " );
124
121
auto * bias = ctx.Input <framework::LoDTensor>(" Bias" );
125
122
auto * in_grad =
126
123
ctx.Output <framework::LoDTensor>(framework::GradVarName (" X" ));
127
124
bool is_sparse = ctx.Attr <bool >(" is_sparse" );
128
125
auto & dev_ctx = ctx.template device_context <DeviceContext>();
129
126
math::SetConstant<DeviceContext, T> zero;
130
- auto * label = ctx.Input <framework::LoDTensor>(" Label" );
131
- auto * pre_out = ctx.Input <framework::LoDTensor>(" PreOut" );
132
- auto * out_grad =
133
- ctx.Input <framework::LoDTensor>(framework::GradVarName (" Out" ));
127
+ auto label = detail::Ref ( ctx.Input <framework::LoDTensor>(" Label" ) );
128
+ auto pre_out = detail::Ref ( ctx.Input <framework::LoDTensor>(" PreOut" ) );
129
+ auto out_grad = detail::Ref (
130
+ ctx.Input <framework::LoDTensor>(framework::GradVarName (" Out" ))) ;
134
131
framework::LoDTensor pre_out_grad;
135
132
136
- pre_out_grad.mutable_data <T>(pre_out-> dims (), ctx.GetPlace ());
133
+ pre_out_grad.mutable_data <T>(pre_out. dims (), ctx.GetPlace ());
137
134
in_grad->mutable_data <T>(ctx.GetPlace ());
138
135
zero (dev_ctx, in_grad, static_cast <T>(0.0 ));
139
136
@@ -147,16 +144,16 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
147
144
std::unique_ptr<math::MatrixBitCodeFunctor<T>> bit_code;
148
145
if (!is_custom) {
149
146
bit_code.reset (new math::MatrixBitCodeFunctor<T>(num_classes,
150
- label-> data <int64_t >()));
147
+ label. data <int64_t >()));
151
148
} else {
152
- bit_code.reset (new math::MatrixBitCodeFunctor<T>(path, code,
153
- label-> data <int64_t >()));
149
+ bit_code.reset (new math::MatrixBitCodeFunctor<T>(* path, * code,
150
+ label. data <int64_t >()));
154
151
}
155
152
156
153
auto & place = *ctx.template device_context <DeviceContext>().eigen_device ();
157
- auto pre_out_mat = EigenMatrix<T>::From (* pre_out);
154
+ auto pre_out_mat = EigenMatrix<T>::From (pre_out);
158
155
auto pre_out_grad_mat = EigenMatrix<T>::From (pre_out_grad);
159
- auto out_grad_mat = EigenMatrix<T>::From (* out_grad);
156
+ auto out_grad_mat = EigenMatrix<T>::From (out_grad);
160
157
161
158
Eigen::array<int , 2 > bcast{1 , static_cast <int >(pre_out_grad.dims ()[1 ])};
162
159
@@ -181,17 +178,17 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
181
178
ctx.Output <framework::LoDTensor>(framework::GradVarName (" W" ));
182
179
w_grad->mutable_data <T>(ctx.GetPlace ());
183
180
zero (dev_ctx, w_grad, static_cast <T>(0.0 ));
184
- bit_code->MulGradWeight (pre_out_grad, w_grad, * in);
181
+ bit_code->MulGradWeight (pre_out_grad, w_grad, in);
185
182
} else {
186
- framework::Vector<int64_t > real_rows = cal_rows (*path);
183
+ framework::Vector<int64_t > real_rows = PathToRows (*path);
187
184
auto * w_grad =
188
185
ctx.Output <framework::SelectedRows>(framework::GradVarName (" W" ));
189
186
w_grad->set_rows (real_rows);
190
187
// Build a map of id -> row_index to speed up finding the index of one id
191
188
w_grad->SyncIndex ();
192
- w_grad->set_height (w-> dims ()[0 ]);
189
+ w_grad->set_height (w. dims ()[0 ]);
193
190
auto * w_grad_value = w_grad->mutable_value ();
194
- framework::DDim temp_dim (w-> dims ());
191
+ framework::DDim temp_dim (w. dims ());
195
192
set (temp_dim, 0 , real_rows.size ());
196
193
197
194
w_grad_value->mutable_data <T>(temp_dim, ctx.GetPlace ());
@@ -211,9 +208,9 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
211
208
zero (dev_ctx, bias_grad_value, static_cast <T>(0.0 ));
212
209
bit_code->AddGrad (pre_out_grad, bias_grad);
213
210
}
214
- bit_code->MulGradWeight (pre_out_grad, w_grad, * in);
211
+ bit_code->MulGradWeight (pre_out_grad, w_grad, in);
215
212
}
216
- bit_code->MulGradError (pre_out_grad, * w, in_grad);
213
+ bit_code->MulGradError (pre_out_grad, w, in_grad);
217
214
}
218
215
};
219
216
0 commit comments