Skip to content

Commit c93a624

Browse files
author
Qingsheng Li
authored
Merge pull request #10052 from ktlichkid/fix-10026
Added kernel to Beam Search OP
2 parents c2e4756 + 79be1bb commit c93a624

File tree

2 files changed

+44
-48
lines changed

2 files changed

+44
-48
lines changed

paddle/fluid/operators/beam_search_op.cc

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -195,10 +195,9 @@ std::string ItemToString(const BeamSearch::Item &item) {
195195
return stream.str();
196196
}
197197

198-
class BeamSearchProtoAndCheckerMaker
199-
: public framework::OpProtoAndCheckerMaker {
198+
class BeamSearchOpMaker : public framework::OpProtoAndCheckerMaker {
200199
public:
201-
BeamSearchProtoAndCheckerMaker(OpProto *proto, OpAttrChecker *op_checker)
200+
BeamSearchOpMaker(OpProto *proto, OpAttrChecker *op_checker)
202201
: OpProtoAndCheckerMaker(proto, op_checker) {
203202
// inputs and outputs stored in proto
204203
AddInput("pre_ids", "ids in previous step");
@@ -222,20 +221,32 @@ class BeamSearchProtoAndCheckerMaker
222221
}
223222
};
224223

225-
class BeamSearchInferShape : public framework::InferShapeBase {
224+
class BeamSearchOp : public framework::OperatorWithKernel {
226225
public:
227-
void operator()(framework::InferShapeContext *context) const override {
226+
using framework::OperatorWithKernel::OperatorWithKernel;
227+
228+
protected:
229+
void InferShape(framework::InferShapeContext *ctx) const override {
228230
for (const std::string &arg :
229231
std::vector<std::string>({"pre_ids", "ids", "scores"})) {
230-
PADDLE_ENFORCE(context->HasInput(arg),
231-
"BeamSearch need input argument '%s'", arg);
232+
PADDLE_ENFORCE(ctx->HasInput(arg), "BeamSearch need input argument '%s'",
233+
arg);
232234
}
233235
for (const std::string &arg :
234236
std::vector<std::string>({"selected_ids", "selected_scores"})) {
235-
PADDLE_ENFORCE(context->HasOutput(arg),
237+
PADDLE_ENFORCE(ctx->HasOutput(arg),
236238
"BeamSearch need output argument '%s'", arg);
237239
}
238240
}
241+
242+
framework::OpKernelType GetExpectedKernelType(
243+
const framework::ExecutionContext &ctx) const override {
244+
framework::OpKernelType kt = framework::OpKernelType(
245+
framework::ToDataType(
246+
ctx.Input<framework::LoDTensor>("pre_ids")->type()),
247+
platform::CPUPlace());
248+
return kt;
249+
}
239250
};
240251

241252
class BeamSearchInferVarType : public framework::VarTypeInference {
@@ -254,8 +265,13 @@ class BeamSearchInferVarType : public framework::VarTypeInference {
254265
} // namespace operators
255266
} // namespace paddle
256267

257-
REGISTER_OPERATOR(beam_search, paddle::operators::BeamSearchOp,
258-
paddle::operators::BeamSearchProtoAndCheckerMaker,
259-
paddle::operators::BeamSearchInferShape,
260-
paddle::operators::BeamSearchInferVarType,
261-
paddle::framework::EmptyGradOpMaker);
268+
namespace ops = paddle::operators;
269+
270+
REGISTER_OPERATOR(beam_search, ops::BeamSearchOp, ops::BeamSearchOpMaker,
271+
ops::BeamSearchInferVarType);
272+
REGISTER_OP_CPU_KERNEL(
273+
beam_search,
274+
ops::BeamSearchOpKernel<paddle::platform::CPUDeviceContext, float>,
275+
ops::BeamSearchOpKernel<paddle::platform::CPUDeviceContext, double>,
276+
ops::BeamSearchOpKernel<paddle::platform::CPUDeviceContext, int>,
277+
ops::BeamSearchOpKernel<paddle::platform::CPUDeviceContext, int64_t>);

paddle/fluid/operators/beam_search_op.h

Lines changed: 15 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -192,49 +192,29 @@ std::ostream& operator<<(std::ostream& os, const BeamSearch::Item& item);
192192

193193
std::string ItemToString(const BeamSearch::Item& item);
194194

195-
class BeamSearchOp : public framework::OperatorBase {
195+
template <typename DeviceContext, typename T>
196+
class BeamSearchOpKernel : public framework::OpKernel<T> {
196197
public:
197-
BeamSearchOp(const std::string& type,
198-
const framework::VariableNameMap& inputs,
199-
const framework::VariableNameMap& outputs,
200-
const framework::AttributeMap& attrs)
201-
: OperatorBase(type, inputs, outputs, attrs) {}
202-
203-
BeamSearchOp(const BeamSearchOp& o)
204-
: framework::OperatorBase(
205-
static_cast<const framework::OperatorBase&>(o)) {
206-
PADDLE_THROW("Not Implemented");
207-
}
208-
209-
private:
210-
void RunImpl(const framework::Scope& scope,
211-
const platform::Place& dev_place) const override {
212-
auto ids_var = scope.FindVar(Input("ids"));
213-
auto scores_var = scope.FindVar(Input("scores"));
214-
auto pre_ids_var = scope.FindVar(Input("pre_ids"));
198+
void Compute(const framework::ExecutionContext& context) const override {
199+
auto* ids_var = context.Input<framework::LoDTensor>("ids");
200+
auto* scores_var = context.Input<framework::LoDTensor>("scores");
201+
auto* pre_ids_var = context.Input<framework::LoDTensor>("pre_ids");
215202
PADDLE_ENFORCE_NOT_NULL(ids_var);
216203
PADDLE_ENFORCE_NOT_NULL(scores_var);
217204
PADDLE_ENFORCE_NOT_NULL(pre_ids_var);
218205

219-
auto& ids = ids_var->Get<framework::LoDTensor>();
220-
auto& scores = scores_var->Get<framework::LoDTensor>();
221-
auto& pre_ids = pre_ids_var->Get<framework::LoDTensor>();
222-
size_t level = Attr<int>("level");
223-
size_t beam_size = Attr<int>("beam_size");
224-
int end_id = Attr<int>("end_id");
225-
BeamSearch alg(ids, scores, level, beam_size, end_id);
226-
227-
auto selected_ids_var = scope.FindVar(Output("selected_ids"));
228-
auto selected_scores_var = scope.FindVar(Output("selected_scores"));
206+
size_t level = context.Attr<int>("level");
207+
size_t beam_size = context.Attr<int>("beam_size");
208+
int end_id = context.Attr<int>("end_id");
209+
BeamSearch alg(*ids_var, *scores_var, level, beam_size, end_id);
210+
auto selected_ids_var =
211+
context.Output<framework::LoDTensor>("selected_ids");
212+
auto selected_scores_var =
213+
context.Output<framework::LoDTensor>("selected_scores");
229214
PADDLE_ENFORCE_NOT_NULL(selected_ids_var);
230215
PADDLE_ENFORCE_NOT_NULL(selected_scores_var);
231-
auto& selected_ids_tensor =
232-
*selected_ids_var->GetMutable<framework::LoDTensor>();
233-
auto& selected_scores_tensor =
234-
*selected_scores_var->GetMutable<framework::LoDTensor>();
235-
alg(pre_ids, &selected_ids_tensor, &selected_scores_tensor);
216+
alg(*pre_ids_var, selected_ids_var, selected_scores_var);
236217
}
237218
};
238-
239219
} // namespace operators
240220
} // namespace paddle

0 commit comments

Comments
 (0)