Skip to content

Commit dd85805

Browse files
authored
Fix beam_search InferShape (#25169) (#25216)
* fix beam_search infershape, test=develop * fix beam search op unittest, test=develop
1 parent 772746c commit dd85805

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

paddle/fluid/operators/beam_search_op.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,10 @@ class BeamSearchOp : public framework::OperatorWithKernel {
9595
std::vector<std::string>({"selected_ids", "selected_scores"})) {
9696
OP_INOUT_CHECK(ctx->HasOutput(arg), "Output", arg, "BeamSeach");
9797
}
98+
auto id_dims = ctx->GetInputDim("pre_ids");
99+
ctx->SetOutputDim("selected_scores", ctx->GetInputDim("pre_scores"));
100+
ctx->SetOutputDim("selected_ids", id_dims);
101+
ctx->SetOutputDim("parent_idx", {id_dims[0]});
98102
}
99103

100104
protected:

python/paddle/fluid/tests/unittests/test_beam_search_op.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,9 @@ def setUp(self):
3838
self._create_pre_scores()
3939
self._create_scores()
4040
self._create_pre_ids()
41-
self.scope.var('selected_ids')
42-
self.scope.var('selected_scores')
43-
self.scope.var('parent_idx')
41+
self.scope.var('selected_ids').get_tensor()
42+
self.scope.var('selected_scores').get_tensor()
43+
self.scope.var('parent_idx').get_tensor()
4444

4545
def test_run(self):
4646
op = Operator(

0 commit comments

Comments
 (0)