Skip to content

Commit 9c86510

Browse files
author
Prikhodko Stanislav
committed
growing seq_len bug fix
1 parent 08ff72a commit 9c86510

File tree

4 files changed

+63
-12
lines changed

4 files changed

+63
-12
lines changed

ctcdecode/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ def decode(self, probs, states, is_eos_s, seq_lens=None):
225225
out_seq_len = torch.zeros(batch_size, self._beam_width).cpu().int()
226226

227227
decode_fn = ctc_decode.paddle_beam_decode_with_given_state
228-
decode_fn(
228+
res = decode_fn(
229229
probs,
230230
seq_lens,
231231
self._num_processes,
@@ -234,10 +234,11 @@ def decode(self, probs, states, is_eos_s, seq_lens=None):
234234
output,
235235
timesteps,
236236
scores,
237-
out_seq_len,
237+
out_seq_len
238238
)
239+
res = res.int()
239240

240-
return output, scores, timesteps, out_seq_len
241+
return res, scores, timesteps, out_seq_len
241242

242243
def character_based(self):
243244
return ctc_decode.is_character_based(self._scorer) if self._scorer else None

ctcdecode/src/binding.cpp

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ void* paddle_get_scorer(double alpha,
150150
}
151151

152152

153-
int beam_decode_with_given_state(at::Tensor th_probs,
153+
torch::Tensor beam_decode_with_given_state(at::Tensor th_probs,
154154
at::Tensor th_seq_lens,
155155
size_t num_processes,
156156
std::vector<void*> &states,
@@ -185,6 +185,31 @@ int beam_decode_with_given_state(at::Tensor th_probs,
185185
std::vector<std::vector<std::pair<double, Output>>> batch_results =
186186
ctc_beam_search_decoder_batch_with_states(inputs, num_processes, states, is_eos_s);
187187
auto outputs_accessor = th_output.accessor<int, 3>();
188+
189+
int max_result_size = 0;
190+
int max_output_tokens_size = 0;
191+
for (int b = 0; b < batch_results.size(); ++b){
192+
std::vector<std::pair<double, Output>> results = batch_results[b];
193+
if (batch_results[b].size() > max_result_size) {
194+
max_result_size = batch_results[b].size();
195+
}
196+
for (int p = 0; p < results.size();++p){
197+
std::pair<double, Output> n_path_result = results[p];
198+
Output output = n_path_result.second;
199+
std::vector<int> output_tokens = output.tokens;
200+
std::vector<int> output_timesteps = output.timesteps;
201+
202+
if (output_tokens.size() > max_output_tokens_size) {
203+
max_output_tokens_size = output_tokens.size();
204+
}
205+
}
206+
}
207+
208+
torch::Tensor tensor = torch::randint(1, {batch_results.size(), max_result_size, max_output_tokens_size});
209+
// cout << batch_results.size() << endl;
210+
// cout << max_result_size << endl;
211+
// cout << max_output_tokens_size << endl;
212+
188213
auto timesteps_accessor = th_timesteps.accessor<int, 3>();
189214
auto scores_accessor = th_scores.accessor<float, 2>();
190215
auto out_length_accessor = th_out_length.accessor<int, 2>();
@@ -205,11 +230,11 @@ int beam_decode_with_given_state(at::Tensor th_probs,
205230
std::vector<int> output_tokens = output.tokens;
206231
std::vector<int> output_timesteps = output.timesteps;
207232
for (int t = 0; t < output_tokens.size(); ++t) {
208-
if (t < outputs_accessor.size(2)) {
209-
outputs_accessor[b][p][t] = output_tokens[t]; // fill output tokens
233+
if (t < tensor.size(2)) {
234+
tensor[b][p][t] = output_tokens[t]; // fill output tokens
210235
}
211236
// else {
212-
// std::cerr << "Unsupported size: t >= outputs_accessor.size(2)\n";
237+
// std::cerr << "Unsupported size: t >= tensor.size(2)\n";
213238
// }
214239

215240
if (t < timesteps_accessor.size(2)) {
@@ -225,12 +250,13 @@ int beam_decode_with_given_state(at::Tensor th_probs,
225250
}
226251
}
227252

228-
229-
return 1;
253+
// torch::Tensor int_tensor = tensor.to(torch::kInt32);
254+
255+
return tensor;
230256
}
231257

232258

233-
int paddle_beam_decode_with_given_state(at::Tensor th_probs,
259+
torch::Tensor paddle_beam_decode_with_given_state(at::Tensor th_probs,
234260
at::Tensor th_seq_lens,
235261
size_t num_processes,
236262
std::vector<void*> states,

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def _single_compile(obj):
131131

132132
setup(
133133
name="ctcdecode",
134-
version="1.1.0",
134+
version="1.0.3",
135135
description="CTC Decoder for PyTorch based on Paddle Paddle's implementation",
136136
url="https://github.com/parlance/ctcdecode",
137137
author="Ryan Leary",

tests/test_decode.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,12 +179,36 @@ def test_online_decoder_decoding_with_two_calls_no_lm(self):
179179
)
180180

181181
del state1, state2
182-
182+
size = beam_results.shape
183183
output_str1 = self.convert_to_string(beam_results[0][0], self.vocab_list, out_seq_len[0][0])
184184
output_str2 = self.convert_to_string(beam_results[1][0], self.vocab_list, out_seq_len[1][0])
185185

186186
self.assertEqual(output_str1, self.beam_search_result[0])
187187
self.assertEqual(output_str2, self.beam_search_result[1])
188+
189+
def test_online_decoder_decoding_with_a_lot_calls_no_lm_check_size(self):
190+
decoder = ctcdecode.OnlineCTCBeamDecoder(
191+
self.vocab_list,
192+
beam_width=self.beam_size,
193+
blank_id=self.vocab_list.index("_"),
194+
log_probs_input=True,
195+
num_processes=24,
196+
)
197+
state1 = ctcdecode.DecoderState(decoder)
198+
199+
probs_seq = torch.FloatTensor([self.probs_seq1]).log()
200+
201+
for i in range(1000):
202+
beam_results, beam_scores, timesteps, out_seq_len = decoder.decode(
203+
probs_seq, [state1], [False, False]
204+
)
205+
206+
beam_results, beam_scores, timesteps, out_seq_len = decoder.decode(
207+
probs_seq, [state1], [True, True]
208+
)
209+
210+
del state1
211+
self.assertGreaterEqual(beam_results.shape[2], out_seq_len.max())
188212

189213

190214
if __name__ == "__main__":

0 commit comments

Comments
 (0)