Skip to content

Commit bc8c7cc

Browse files
authored
Merge pull request #12114 from reyoung/feature/get_beam_search_prob_on_capi
Get BeamSearch Prob on C-API
2 parents 2cb5c35 + 08fa398 commit bc8c7cc

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

paddle/legacy/capi/Arguments.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,17 @@ paddle_error paddle_arguments_get_value(paddle_arguments args,
6666
return kPD_NO_ERROR;
6767
}
6868

69+
PD_API paddle_error paddle_arguments_get_prob(paddle_arguments args,
70+
uint64_t ID,
71+
paddle_matrix mat) {
72+
if (args == nullptr || mat == nullptr) return kPD_NULLPTR;
73+
auto m = paddle::capi::cast<paddle::capi::CMatrix>(mat);
74+
auto a = castArg(args);
75+
if (ID >= a->args.size()) return kPD_OUT_OF_RANGE;
76+
m->mat = a->args[ID].in;
77+
return kPD_NO_ERROR;
78+
}
79+
6980
paddle_error paddle_arguments_get_ids(paddle_arguments args,
7081
uint64_t ID,
7182
paddle_ivector ids) {

paddle/legacy/capi/arguments.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,18 @@ PD_API paddle_error paddle_arguments_get_value(paddle_arguments args,
8787
uint64_t ID,
8888
paddle_matrix mat);
8989

90+
/**
91+
* @brief paddle_arguments_get_prob Get the prob matrix of beam search, which
92+
* slot ID is `ID`
93+
* @param [in] args arguments array
94+
* @param [in] ID array index
95+
* @param [out] mat matrix pointer
96+
* @return paddle_error
97+
*/
98+
PD_API paddle_error paddle_arguments_get_prob(paddle_arguments args,
99+
uint64_t ID,
100+
paddle_matrix mat);
101+
90102
/**
91103
* @brief PDArgsGetIds Get the integer vector of one argument in array, which
92104
* index is `ID`.

0 commit comments

Comments
 (0)