Skip to content

Commit 2f49fae

Browse files
author
Prikhodko Stanislav
committed
timesteps len bug fix
1 parent 9c86510 commit 2f49fae

File tree

2 files changed

+13
-34
lines changed

2 files changed

+13
-34
lines changed

ctcdecode/__init__.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -219,26 +219,23 @@ def decode(self, probs, states, is_eos_s, seq_lens=None):
219219
seq_lens = torch.IntTensor(batch_size).fill_(max_seq_len)
220220
else:
221221
seq_lens = seq_lens.cpu().int()
222-
output = torch.IntTensor(batch_size, self._beam_width, max_seq_len).cpu().int()
223-
timesteps = torch.IntTensor(batch_size, self._beam_width, max_seq_len).cpu().int()
224222
scores = torch.FloatTensor(batch_size, self._beam_width).cpu().float()
225223
out_seq_len = torch.zeros(batch_size, self._beam_width).cpu().int()
226224

227225
decode_fn = ctc_decode.paddle_beam_decode_with_given_state
228-
res = decode_fn(
226+
res_beam_results, res_timesteps = decode_fn(
229227
probs,
230228
seq_lens,
231229
self._num_processes,
232230
[state.state for state in states],
233231
is_eos_s,
234-
output,
235-
timesteps,
236232
scores,
237233
out_seq_len
238234
)
239-
res = res.int()
235+
res_beam_results = res_beam_results.int()
236+
res_timesteps = res_timesteps.int()
240237

241-
return res, scores, timesteps, out_seq_len
238+
return res_beam_results, scores, res_timesteps, out_seq_len
242239

243240
def character_based(self):
244241
return ctc_decode.is_character_based(self._scorer) if self._scorer else None

ctcdecode/src/binding.cpp

Lines changed: 9 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -150,13 +150,11 @@ void* paddle_get_scorer(double alpha,
150150
}
151151

152152

153-
torch::Tensor beam_decode_with_given_state(at::Tensor th_probs,
153+
std::pair<torch::Tensor, torch::Tensor> beam_decode_with_given_state(at::Tensor th_probs,
154154
at::Tensor th_seq_lens,
155155
size_t num_processes,
156156
std::vector<void*> &states,
157157
const std::vector<bool> &is_eos_s,
158-
at::Tensor th_output,
159-
at::Tensor th_timesteps,
160158
at::Tensor th_scores,
161159
at::Tensor th_out_length)
162160
{
@@ -184,7 +182,6 @@ torch::Tensor beam_decode_with_given_state(at::Tensor th_probs,
184182

185183
std::vector<std::vector<std::pair<double, Output>>> batch_results =
186184
ctc_beam_search_decoder_batch_with_states(inputs, num_processes, states, is_eos_s);
187-
auto outputs_accessor = th_output.accessor<int, 3>();
188185

189186
int max_result_size = 0;
190187
int max_output_tokens_size = 0;
@@ -197,20 +194,20 @@ torch::Tensor beam_decode_with_given_state(at::Tensor th_probs,
197194
std::pair<double, Output> n_path_result = results[p];
198195
Output output = n_path_result.second;
199196
std::vector<int> output_tokens = output.tokens;
200-
std::vector<int> output_timesteps = output.timesteps;
201197

202198
if (output_tokens.size() > max_output_tokens_size) {
203199
max_output_tokens_size = output_tokens.size();
204200
}
205201
}
206202
}
207203

208-
torch::Tensor tensor = torch::randint(1, {batch_results.size(), max_result_size, max_output_tokens_size});
204+
torch::Tensor output_tokens_tensor = torch::randint(1, {batch_results.size(), max_result_size, max_output_tokens_size});
205+
torch::Tensor output_timesteps_tensor = torch::randint(1, {batch_results.size(), max_result_size, max_output_tokens_size});
206+
209207
// cout << batch_results.size() << endl;
210208
// cout << max_result_size << endl;
211209
// cout << max_output_tokens_size << endl;
212210

213-
auto timesteps_accessor = th_timesteps.accessor<int, 3>();
214211
auto scores_accessor = th_scores.accessor<float, 2>();
215212
auto out_length_accessor = th_out_length.accessor<int, 2>();
216213

@@ -230,20 +227,8 @@ torch::Tensor beam_decode_with_given_state(at::Tensor th_probs,
230227
std::vector<int> output_tokens = output.tokens;
231228
std::vector<int> output_timesteps = output.timesteps;
232229
for (int t = 0; t < output_tokens.size(); ++t) {
233-
if (t < tensor.size(2)) {
234-
tensor[b][p][t] = output_tokens[t]; // fill output tokens
235-
}
236-
// else {
237-
// std::cerr << "Unsupported size: t >= tensor.size(2)\n";
238-
// }
239-
240-
if (t < timesteps_accessor.size(2)) {
241-
timesteps_accessor[b][p][t] = output_timesteps[t];
242-
}
243-
// TODO: понять, почему
244-
// else {
245-
// std::cout << "Unsupported size: t >= timesteps_accessor.size(2)\n";
246-
// }
230+
output_tokens_tensor[b][p][t] = output_tokens[t]; // fill output tokens
231+
output_timesteps_tensor[b][p][t] = output_timesteps[t];
247232
}
248233
scores_accessor[b][p] = n_path_result.first;
249234
out_length_accessor[b][p] = output_tokens.size();
@@ -252,22 +237,19 @@ torch::Tensor beam_decode_with_given_state(at::Tensor th_probs,
252237

253238
// torch::Tensor int_tensor = tensor.to(torch::kInt32);
254239

255-
return tensor;
240+
return {output_tokens_tensor, output_timesteps_tensor};
256241
}
257242

258243

259-
torch::Tensor paddle_beam_decode_with_given_state(at::Tensor th_probs,
244+
std::pair<torch::Tensor, torch::Tensor> paddle_beam_decode_with_given_state(at::Tensor th_probs,
260245
at::Tensor th_seq_lens,
261246
size_t num_processes,
262247
std::vector<void*> states,
263248
std::vector<bool> is_eos_s,
264-
at::Tensor th_output,
265-
at::Tensor th_timesteps,
266249
at::Tensor th_scores,
267250
at::Tensor th_out_length){
268251

269-
return beam_decode_with_given_state(th_probs, th_seq_lens, num_processes, states,is_eos_s,
270-
th_output, th_timesteps, th_scores, th_out_length);
252+
return beam_decode_with_given_state(th_probs, th_seq_lens, num_processes, states,is_eos_s, th_scores, th_out_length);
271253
}
272254

273255

0 commit comments

Comments
 (0)