@@ -150,7 +150,7 @@ void* paddle_get_scorer(double alpha,
150150}
151151
152152
153- int beam_decode_with_given_state (at::Tensor th_probs,
153+ torch::Tensor beam_decode_with_given_state (at::Tensor th_probs,
154154 at::Tensor th_seq_lens,
155155 size_t num_processes,
156156 std::vector<void *> &states,
@@ -185,6 +185,31 @@ int beam_decode_with_given_state(at::Tensor th_probs,
185185 std::vector<std::vector<std::pair<double , Output>>> batch_results =
186186 ctc_beam_search_decoder_batch_with_states (inputs, num_processes, states, is_eos_s);
187187 auto outputs_accessor = th_output.accessor <int , 3 >();
188+
189+ int max_result_size = 0 ;
190+ int max_output_tokens_size = 0 ;
191+ for (int b = 0 ; b < batch_results.size (); ++b){
192+ std::vector<std::pair<double , Output>> results = batch_results[b];
193+ if (batch_results[b].size () > max_result_size) {
194+ max_result_size = batch_results[b].size ();
195+ }
196+ for (int p = 0 ; p < results.size ();++p){
197+ std::pair<double , Output> n_path_result = results[p];
198+ Output output = n_path_result.second ;
199+ std::vector<int > output_tokens = output.tokens ;
200+ std::vector<int > output_timesteps = output.timesteps ;
201+
202+ if (output_tokens.size () > max_output_tokens_size) {
203+ max_output_tokens_size = output_tokens.size ();
204+ }
205+ }
206+ }
207+
208+ torch::Tensor tensor = torch::randint (1 , {batch_results.size (), max_result_size, max_output_tokens_size});
209+ // cout << batch_results.size() << endl;
210+ // cout << max_result_size << endl;
211+ // cout << max_output_tokens_size << endl;
212+
188213 auto timesteps_accessor = th_timesteps.accessor <int , 3 >();
189214 auto scores_accessor = th_scores.accessor <float , 2 >();
190215 auto out_length_accessor = th_out_length.accessor <int , 2 >();
@@ -205,11 +230,11 @@ int beam_decode_with_given_state(at::Tensor th_probs,
205230 std::vector<int > output_tokens = output.tokens ;
206231 std::vector<int > output_timesteps = output.timesteps ;
207232 for (int t = 0 ; t < output_tokens.size (); ++t) {
208- if (t < outputs_accessor .size (2 )) {
209- outputs_accessor [b][p][t] = output_tokens[t]; // fill output tokens
233+ if (t < tensor .size (2 )) {
234+ tensor [b][p][t] = output_tokens[t]; // fill output tokens
210235 }
211236 // else {
212- // std::cerr << "Unsupported size: t >= outputs_accessor .size(2)\n";
237+ // std::cerr << "Unsupported size: t >= tensor .size(2)\n";
213238 // }
214239
215240 if (t < timesteps_accessor.size (2 )) {
@@ -225,12 +250,13 @@ int beam_decode_with_given_state(at::Tensor th_probs,
225250 }
226251 }
227252
228-
229- return 1 ;
253+ // torch::Tensor int_tensor = tensor.to(torch::kInt32);
254+
255+ return tensor;
230256}
231257
232258
233- int paddle_beam_decode_with_given_state (at::Tensor th_probs,
259+ torch::Tensor paddle_beam_decode_with_given_state (at::Tensor th_probs,
234260 at::Tensor th_seq_lens,
235261 size_t num_processes,
236262 std::vector<void *> states,
0 commit comments