@@ -150,13 +150,11 @@ void* paddle_get_scorer(double alpha,
 }
 
 
-torch::Tensor beam_decode_with_given_state(at::Tensor th_probs,
+std::pair<torch::Tensor, torch::Tensor> beam_decode_with_given_state(at::Tensor th_probs,
                                            at::Tensor th_seq_lens,
                                            size_t num_processes,
                                            std::vector<void*> &states,
                                            const std::vector<bool> &is_eos_s,
-                                           at::Tensor th_output,
-                                           at::Tensor th_timesteps,
                                            at::Tensor th_scores,
                                            at::Tensor th_out_length)
 {
@@ -184,7 +182,6 @@ torch::Tensor beam_decode_with_given_state(at::Tensor th_probs,
 
     std::vector<std::vector<std::pair<double, Output>>> batch_results =
         ctc_beam_search_decoder_batch_with_states(inputs, num_processes, states, is_eos_s);
-    auto outputs_accessor = th_output.accessor<int, 3>();
 
     int max_result_size = 0;
     int max_output_tokens_size = 0;
@@ -197,20 +194,20 @@ torch::Tensor beam_decode_with_given_state(at::Tensor th_probs,
             std::pair<double, Output> n_path_result = results[p];
             Output output = n_path_result.second;
             std::vector<int> output_tokens = output.tokens;
-            std::vector<int> output_timesteps = output.timesteps;
 
             if (output_tokens.size() > max_output_tokens_size) {
                 max_output_tokens_size = output_tokens.size();
             }
         }
     }
 
-    torch::Tensor tensor = torch::randint(1, {batch_results.size(), max_result_size, max_output_tokens_size});
+    torch::Tensor output_tokens_tensor = torch::randint(1, {batch_results.size(), max_result_size, max_output_tokens_size});
+    torch::Tensor output_timesteps_tensor = torch::randint(1, {batch_results.size(), max_result_size, max_output_tokens_size});
+
     // cout << batch_results.size() << endl;
     // cout << max_result_size << endl;
     // cout << max_output_tokens_size << endl;
 
-    auto timesteps_accessor = th_timesteps.accessor<int, 3>();
     auto scores_accessor = th_scores.accessor<float, 2>();
     auto out_length_accessor = th_out_length.accessor<int, 2>();
 
@@ -230,20 +227,8 @@ torch::Tensor beam_decode_with_given_state(at::Tensor th_probs,
             std::vector<int> output_tokens = output.tokens;
             std::vector<int> output_timesteps = output.timesteps;
             for (int t = 0; t < output_tokens.size(); ++t) {
-                if (t < tensor.size(2)) {
-                    tensor[b][p][t] = output_tokens[t]; // fill output tokens
-                }
-                // else {
-                //     std::cerr << "Unsupported size: t >= tensor.size(2)\n";
-                // }
-
-                if (t < timesteps_accessor.size(2)) {
-                    timesteps_accessor[b][p][t] = output_timesteps[t];
-                }
-                // TODO: figure out why
-                // else {
-                //     std::cout << "Unsupported size: t >= timesteps_accessor.size(2)\n";
-                // }
+                output_tokens_tensor[b][p][t] = output_tokens[t]; // fill output tokens
+                output_timesteps_tensor[b][p][t] = output_timesteps[t];
             }
             scores_accessor[b][p] = n_path_result.first;
             out_length_accessor[b][p] = output_tokens.size();
@@ -252,22 +237,19 @@ torch::Tensor beam_decode_with_given_state(at::Tensor th_probs,
 
     // torch::Tensor int_tensor = tensor.to(torch::kInt32);
 
-    return tensor;
+    return {output_tokens_tensor, output_timesteps_tensor};
 }
 
 
-torch::Tensor paddle_beam_decode_with_given_state(at::Tensor th_probs,
+std::pair<torch::Tensor, torch::Tensor> paddle_beam_decode_with_given_state(at::Tensor th_probs,
                                                   at::Tensor th_seq_lens,
                                                   size_t num_processes,
                                                   std::vector<void*> states,
                                                   std::vector<bool> is_eos_s,
-                                                  at::Tensor th_output,
-                                                  at::Tensor th_timesteps,
                                                   at::Tensor th_scores,
                                                   at::Tensor th_out_length){
 
-    return beam_decode_with_given_state(th_probs, th_seq_lens, num_processes, states, is_eos_s,
-                                        th_output, th_timesteps, th_scores, th_out_length);
+    return beam_decode_with_given_state(th_probs, th_seq_lens, num_processes, states, is_eos_s, th_scores, th_out_length);
 }
 
 
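With this change, the decode functions return the token and timestep tensors directly instead of filling caller-supplied th_output/th_timesteps buffers, so a caller now only pre-allocates the score and length tensors. The sketch below is a hypothetical C++ caller and is not part of this commit: the decode_example and beam_width names are invented, and the shapes/dtypes are assumptions inferred from the accessors above (2-D float scores, 2-D int lengths). Any binding code that exposes paddle_beam_decode_with_given_state to Python would also need its return type updated to a two-tensor tuple.

// Hypothetical caller sketch (assumptions noted above), not taken from this commit.
#include <torch/torch.h>
#include <utility>
#include <vector>

// Declaration matching the patched signature.
std::pair<torch::Tensor, torch::Tensor> paddle_beam_decode_with_given_state(at::Tensor th_probs,
                                                                            at::Tensor th_seq_lens,
                                                                            size_t num_processes,
                                                                            std::vector<void*> states,
                                                                            std::vector<bool> is_eos_s,
                                                                            at::Tensor th_scores,
                                                                            at::Tensor th_out_length);

void decode_example(at::Tensor probs,            // assumed (batch, time, vocab), float
                    at::Tensor seq_lens,         // assumed (batch,), int
                    std::vector<void*> &states,  // opaque decoder states, created elsewhere
                    std::vector<bool> &is_eos_s,
                    int beam_width) {
    const int64_t batch_size = probs.size(0);

    // Scores and per-path lengths are still written through pre-allocated
    // tensors; tokens and timesteps now come back as the returned pair.
    auto scores = torch::zeros({batch_size, beam_width}, torch::kFloat32);
    auto out_lengths = torch::zeros({batch_size, beam_width}, torch::kInt32);

    auto result = paddle_beam_decode_with_given_state(probs, seq_lens, /*num_processes=*/4,
                                                      states, is_eos_s, scores, out_lengths);

    torch::Tensor tokens = result.first;      // replaces the old th_output argument
    torch::Tensor timesteps = result.second;  // replaces the old th_timesteps argument
}

Because the output tensors are now sized inside the function from max_result_size and max_output_tokens_size, callers no longer have to guess the maximum token count up front, which is presumably why the old t < tensor.size(2) guards could be dropped.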