@@ -192,49 +192,29 @@ std::ostream& operator<<(std::ostream& os, const BeamSearch::Item& item);
192
192
193
193
std::string ItemToString (const BeamSearch::Item& item);
194
194
195
- class BeamSearchOp : public framework ::OperatorBase {
195
+ template <typename DeviceContext, typename T>
196
+ class BeamSearchOpKernel : public framework ::OpKernel<T> {
196
197
public:
197
- BeamSearchOp (const std::string& type,
198
- const framework::VariableNameMap& inputs,
199
- const framework::VariableNameMap& outputs,
200
- const framework::AttributeMap& attrs)
201
- : OperatorBase(type, inputs, outputs, attrs) {}
202
-
203
- BeamSearchOp (const BeamSearchOp& o)
204
- : framework::OperatorBase(
205
- static_cast <const framework::OperatorBase&>(o)) {
206
- PADDLE_THROW (" Not Implemented" );
207
- }
208
-
209
- private:
210
- void RunImpl (const framework::Scope& scope,
211
- const platform::Place& dev_place) const override {
212
- auto ids_var = scope.FindVar (Input (" ids" ));
213
- auto scores_var = scope.FindVar (Input (" scores" ));
214
- auto pre_ids_var = scope.FindVar (Input (" pre_ids" ));
198
+ void Compute (const framework::ExecutionContext& context) const override {
199
+ auto * ids_var = context.Input <framework::LoDTensor>(" ids" );
200
+ auto * scores_var = context.Input <framework::LoDTensor>(" scores" );
201
+ auto * pre_ids_var = context.Input <framework::LoDTensor>(" pre_ids" );
215
202
PADDLE_ENFORCE_NOT_NULL (ids_var);
216
203
PADDLE_ENFORCE_NOT_NULL (scores_var);
217
204
PADDLE_ENFORCE_NOT_NULL (pre_ids_var);
218
205
219
- auto & ids = ids_var->Get <framework::LoDTensor>();
220
- auto & scores = scores_var->Get <framework::LoDTensor>();
221
- auto & pre_ids = pre_ids_var->Get <framework::LoDTensor>();
222
- size_t level = Attr<int >(" level" );
223
- size_t beam_size = Attr<int >(" beam_size" );
224
- int end_id = Attr<int >(" end_id" );
225
- BeamSearch alg (ids, scores, level, beam_size, end_id);
226
-
227
- auto selected_ids_var = scope.FindVar (Output (" selected_ids" ));
228
- auto selected_scores_var = scope.FindVar (Output (" selected_scores" ));
206
+ size_t level = context.Attr <int >(" level" );
207
+ size_t beam_size = context.Attr <int >(" beam_size" );
208
+ int end_id = context.Attr <int >(" end_id" );
209
+ BeamSearch alg (*ids_var, *scores_var, level, beam_size, end_id);
210
+ auto selected_ids_var =
211
+ context.Output <framework::LoDTensor>(" selected_ids" );
212
+ auto selected_scores_var =
213
+ context.Output <framework::LoDTensor>(" selected_scores" );
229
214
PADDLE_ENFORCE_NOT_NULL (selected_ids_var);
230
215
PADDLE_ENFORCE_NOT_NULL (selected_scores_var);
231
- auto & selected_ids_tensor =
232
- *selected_ids_var->GetMutable <framework::LoDTensor>();
233
- auto & selected_scores_tensor =
234
- *selected_scores_var->GetMutable <framework::LoDTensor>();
235
- alg (pre_ids, &selected_ids_tensor, &selected_scores_tensor);
216
+ alg (*pre_ids_var, selected_ids_var, selected_scores_var);
236
217
}
237
218
};
238
-
239
219
} // namespace operators
240
220
} // namespace paddle
0 commit comments