@@ -14,6 +14,7 @@ limitations under the License. */
14
14
15
15
#include " paddle/fluid/operators/nce_op.h"
16
16
17
+ #include < string>
17
18
#include < vector>
18
19
19
20
namespace paddle {
@@ -25,7 +26,7 @@ class NCEOp : public framework::OperatorWithKernel {
25
26
public:
26
27
using framework::OperatorWithKernel::OperatorWithKernel;
27
28
28
- void InferShape (framework::InferShapeContext* ctx) const override {
29
+ void InferShape (framework::InferShapeContext * ctx) const override {
29
30
PADDLE_ENFORCE (ctx->HasInput (" Input" ));
30
31
PADDLE_ENFORCE (ctx->HasInput (" Label" ));
31
32
PADDLE_ENFORCE (ctx->HasInput (" Weight" ));
@@ -67,7 +68,7 @@ class NCEOp : public framework::OperatorWithKernel {
67
68
68
69
protected:
69
70
framework::OpKernelType GetExpectedKernelType (
70
- const framework::ExecutionContext& ctx) const override {
71
+ const framework::ExecutionContext & ctx) const override {
71
72
return framework::OpKernelType (
72
73
framework::ToDataType (ctx.Input <Tensor>(" Input" )->type ()),
73
74
platform::CPUPlace ());
@@ -101,11 +102,24 @@ class NCEOpMaker : public framework::OpProtoAndCheckerMaker {
101
102
.AsDispensable ();
102
103
103
104
AddInput (
104
- " CustomDistribution " ,
105
+ " CustomDistProbs " ,
105
106
" (Tensor) It is used in 'CostumDist' sampler. "
106
107
" It is a tensor with shape [num_total_classes]."
107
108
" The i-th element is the probsbility of the i-th class being sampled." )
108
109
.AsDispensable ();
110
+ AddInput (
111
+ " CustomDistAlias" ,
112
+ " (Tensor) It is used in 'CostumDist' sampler. "
113
+ " It is a tensor with shape [num_total_classes]."
114
+ " The i-th element is the probsbility of the i-th class being sampled." )
115
+ .AsDispensable ();
116
+ AddInput (
117
+ " CustomDistAliasProbs" ,
118
+ " (Tensor) It is used in 'CostumDist' sampler. "
119
+ " It is a tensor with shape [num_total_classes]."
120
+ " The i-th element is the probsbility of the i-th class being sampled." )
121
+ .AsDispensable ();
122
+
109
123
AddOutput (" Cost" ,
110
124
" (Tensor) A tensor of shape [batch_size, 1]. Cost of samples." );
111
125
AddOutput (" SampleLogits" ,
@@ -124,21 +138,22 @@ class NCEOpMaker : public framework::OpProtoAndCheckerMaker {
124
138
" kernel to compute grads."
125
139
" " )
126
140
.AsIntermediate ();
141
+
127
142
AddAttr<int >(" num_total_classes" ,
128
143
" Total number of classes in all samples." );
129
144
AddAttr<int >(" num_neg_samples" ,
130
145
" The number of negative classes. The default value is 10." )
131
146
.SetDefault (10 );
132
-
133
147
AddAttr<int >(" sampler" ,
134
148
" (int) Which sampler to be used to sample negative class."
135
149
" 0: Uniform; 1: LogUniform; 2: CostumDist." )
136
150
.SetDefault (0 );
137
-
138
151
AddAttr<int >(" seed" ,
139
152
" (int) The seed used in sampler. If it is 0, "
140
153
" the sampler will generate a seed randomly." )
141
154
.SetDefault (0 );
155
+ AddAttr<bool >(" is_sparse" , " (boolean, default false) Sparse update." )
156
+ .SetDefault (false );
142
157
143
158
AddAttr<std::vector<int >>(" custom_neg_classes" ,
144
159
" This attribute only be used in unitest. Classes "
@@ -156,11 +171,19 @@ By default this operator uses a uniform distribution for sampling.
156
171
}
157
172
};
158
173
174
+ class NCEOpGradDescMaker : public framework ::DefaultGradOpDescMaker<true > {
175
+ using ::paddle::framework::DefaultGradOpDescMaker<
176
+ true >::DefaultGradOpDescMaker;
177
+
178
+ protected:
179
+ virtual std::string GradOpType () const { return " nce_grad" ; }
180
+ };
181
+
159
182
class NCEOpGrad : public framework ::OperatorWithKernel {
160
183
public:
161
184
using framework::OperatorWithKernel::OperatorWithKernel;
162
185
163
- void InferShape (framework::InferShapeContext* ctx) const override {
186
+ void InferShape (framework::InferShapeContext * ctx) const override {
164
187
PADDLE_ENFORCE (ctx->HasInput (" Input" ));
165
188
PADDLE_ENFORCE (ctx->HasInput (" Weight" ));
166
189
PADDLE_ENFORCE (ctx->HasInput (" Cost" ));
@@ -190,20 +213,45 @@ class NCEOpGrad : public framework::OperatorWithKernel {
190
213
191
214
protected:
192
215
framework::OpKernelType GetExpectedKernelType (
193
- const framework::ExecutionContext& ctx) const override {
216
+ const framework::ExecutionContext & ctx) const override {
194
217
return framework::OpKernelType (
195
218
framework::ToDataType (ctx.Input <Tensor>(" Input" )->type ()),
196
219
platform::CPUPlace ());
197
220
}
198
221
};
199
222
223
+ class NCEOpGradVarTypeInference : public framework ::VarTypeInference {
224
+ public:
225
+ void operator ()(const framework::OpDesc &op_desc,
226
+ framework::BlockDesc *block) const override {
227
+ auto weight_grad = op_desc.Output (framework::GradVarName (" Weight" )).front ();
228
+ auto bias_grad = op_desc.Output (framework::GradVarName (" Bias" )).front ();
229
+
230
+ auto attr = op_desc.GetAttr (" is_sparse" );
231
+ bool is_sparse = boost::get<bool >(attr);
232
+ if (is_sparse) {
233
+ VLOG (30 ) << " nce_op_grad op " << weight_grad << " and " << bias_grad
234
+ << " is set to SelectedRows" ;
235
+ block->Var (weight_grad)
236
+ ->SetType (framework::proto::VarType::SELECTED_ROWS);
237
+ block->Var (bias_grad)->SetType (framework::proto::VarType::SELECTED_ROWS);
238
+ } else {
239
+ VLOG (30 ) << " nce_op_grad op " << weight_grad << " and " << bias_grad
240
+ << " is set to LoDTensor" ;
241
+ block->Var (weight_grad)->SetType (framework::proto::VarType::LOD_TENSOR);
242
+ block->Var (bias_grad)->SetType (framework::proto::VarType::LOD_TENSOR);
243
+ }
244
+ block->Var (weight_grad)->SetDataType (block->Var (" Input" )->GetDataType ());
245
+ block->Var (bias_grad)->SetDataType (block->Var (" Input" )->GetDataType ());
246
+ }
247
+ };
248
+
200
249
} // namespace operators
201
250
} // namespace paddle
202
251
203
252
namespace ops = paddle::operators;
204
- REGISTER_OPERATOR (nce, ops::NCEOp, ops::NCEOpMaker,
205
- paddle::framework::DefaultGradOpDescMaker<true >);
206
- REGISTER_OPERATOR (nce_grad, ops::NCEOpGrad);
253
+ REGISTER_OPERATOR (nce, ops::NCEOp, ops::NCEOpGradDescMaker, ops::NCEOpMaker);
254
+ REGISTER_OPERATOR (nce_grad, ops::NCEOpGrad, ops::NCEOpGradVarTypeInference);
207
255
REGISTER_OP_CPU_KERNEL (nce, ops::NCEKernel<paddle::platform::CPUPlace, float >,
208
256
ops::NCEKernel<paddle::platform::CPUPlace, double >);
209
257
REGISTER_OP_CPU_KERNEL (nce_grad,
0 commit comments