Skip to content

Commit 62f3364

Browse files
author
Sean Naren
authored
Merge pull request #112 from reuben/streaming_api
Expose streaming API
2 parents 6929bf5 + ca630b4 commit 62f3364

File tree

4 files changed

+149
-39
lines changed

4 files changed

+149
-39
lines changed

ctcdecode/src/ctc_beam_search_decoder.cpp

Lines changed: 76 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -14,39 +14,33 @@
1414

1515
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
1616

17-
std::vector<std::pair<double, Output>> ctc_beam_search_decoder(
18-
const std::vector<std::vector<double>> &probs_seq,
19-
const std::vector<std::string> &vocabulary,
20-
size_t beam_size,
21-
double cutoff_prob,
22-
size_t cutoff_top_n,
23-
size_t blank_id,
24-
int log_input,
25-
Scorer *ext_scorer) {
26-
// dimension check
27-
size_t num_time_steps = probs_seq.size();
28-
for (size_t i = 0; i < num_time_steps; ++i) {
29-
VALID_CHECK_EQ(probs_seq[i].size(),
30-
vocabulary.size(),
31-
"The shape of probs_seq does not match with "
32-
"the shape of the vocabulary");
33-
}
34-
35-
// assign blank id
36-
// size_t blank_id = vocabulary.size();
37-
17+
DecoderState::DecoderState(const std::vector<std::string> &vocabulary,
18+
size_t beam_size,
19+
double cutoff_prob,
20+
size_t cutoff_top_n,
21+
size_t blank_id,
22+
int log_input,
23+
Scorer *ext_scorer)
24+
: abs_time_step(0)
25+
, beam_size(beam_size)
26+
, cutoff_prob(cutoff_prob)
27+
, cutoff_top_n(cutoff_top_n)
28+
, blank_id(blank_id)
29+
, log_input(log_input)
30+
, vocabulary(vocabulary)
31+
, ext_scorer(ext_scorer)
32+
{
3833
// assign space id
3934
auto it = std::find(vocabulary.begin(), vocabulary.end(), " ");
40-
int space_id = it - vocabulary.begin();
4135
// if no space in vocabulary
42-
if ((size_t)space_id >= vocabulary.size()) {
36+
if (it == vocabulary.end()) {
4337
space_id = -2;
38+
} else {
39+
space_id = std::distance(vocabulary.begin(), it);
4440
}
4541

4642
// init prefixes' root
47-
PathTrie root;
4843
root.score = root.log_prob_b_prev = 0.0;
49-
std::vector<PathTrie *> prefixes;
5044
prefixes.push_back(&root);
5145

5246
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
@@ -56,9 +50,22 @@ std::vector<std::pair<double, Output>> ctc_beam_search_decoder(
5650
auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
5751
root.set_matcher(matcher);
5852
}
53+
}
54+
55+
void
56+
DecoderState::next(const std::vector<std::vector<double>> &probs_seq)
57+
{
58+
// dimension check
59+
size_t num_time_steps = probs_seq.size();
60+
for (size_t i = 0; i < num_time_steps; ++i) {
61+
VALID_CHECK_EQ(probs_seq[i].size(),
62+
vocabulary.size(),
63+
"The shape of probs_seq does not match with "
64+
"the shape of the vocabulary");
65+
}
5966

6067
// prefix search over time
61-
for (size_t time_step = 0; time_step < num_time_steps; ++time_step) {
68+
for (size_t time_step = 0; time_step < num_time_steps; ++time_step, ++abs_time_step) {
6269
auto &prob = probs_seq[time_step];
6370

6471
float min_cutoff = -NUM_FLT_INF;
@@ -97,7 +104,7 @@ std::vector<std::pair<double, Output>> ctc_beam_search_decoder(
97104
prefix->log_prob_nb_cur, log_prob_c + prefix->log_prob_nb_prev);
98105
}
99106
// get new prefix
100-
auto prefix_new = prefix->get_path_trie(c, time_step, log_prob_c);
107+
auto prefix_new = prefix->get_path_trie(c, abs_time_step, log_prob_c);
101108

102109
if (prefix_new != nullptr) {
103110
float log_p = -NUM_FLT_INF;
@@ -147,45 +154,75 @@ std::vector<std::pair<double, Output>> ctc_beam_search_decoder(
147154
for (size_t i = beam_size; i < prefixes.size(); ++i) {
148155
prefixes[i]->remove();
149156
}
157+
158+
prefixes.resize(beam_size);
150159
}
151160
} // end of loop over time
161+
}
162+
163+
std::vector<std::pair<double, Output>>
164+
DecoderState::decode() const
165+
{
166+
std::vector<PathTrie*> prefixes_copy = prefixes;
167+
std::unordered_map<const PathTrie*, float> scores;
168+
for (PathTrie* prefix : prefixes_copy) {
169+
scores[prefix] = prefix->score;
170+
}
152171

153172
// score the last word of each prefix that doesn't end with space
154173
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
155-
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
156-
auto prefix = prefixes[i];
174+
for (size_t i = 0; i < beam_size && i < prefixes_copy.size(); ++i) {
175+
auto prefix = prefixes_copy[i];
157176
if (!prefix->is_empty() && prefix->character != space_id) {
158177
float score = 0.0;
159178
std::vector<std::string> ngram = ext_scorer->make_ngram(prefix);
160179
score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
161180
score += ext_scorer->beta;
162-
prefix->score += score;
181+
scores[prefix] += score;
163182
}
164183
}
165184
}
166185

167-
size_t num_prefixes = std::min(prefixes.size(), beam_size);
168-
std::sort(prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
186+
using namespace std::placeholders;
187+
size_t num_prefixes = std::min(prefixes_copy.size(), beam_size);
188+
std::sort(prefixes_copy.begin(), prefixes_copy.begin() + num_prefixes,
189+
std::bind(prefix_compare_external_scores, _1, _2, scores));
169190

170191
// compute aproximate ctc score as the return score, without affecting the
171192
// return order of decoding result. To delete when decoder gets stable.
172-
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
173-
double approx_ctc = prefixes[i]->score;
193+
for (size_t i = 0; i < beam_size && i < prefixes_copy.size(); ++i) {
194+
double approx_ctc = scores[prefixes_copy[i]];
174195
if (ext_scorer != nullptr) {
175196
std::vector<int> output;
176197
std::vector<int> timesteps;
177-
prefixes[i]->get_path_vec(output, timesteps);
198+
prefixes_copy[i]->get_path_vec(output, timesteps);
178199
auto prefix_length = output.size();
179200
auto words = ext_scorer->split_labels(output);
180201
// remove word insert
181202
approx_ctc = approx_ctc - prefix_length * ext_scorer->beta;
182203
// remove language model weight:
183204
approx_ctc -= (ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha;
184205
}
185-
prefixes[i]->approx_ctc = approx_ctc;
206+
prefixes_copy[i]->approx_ctc = approx_ctc;
186207
}
187208

188-
return get_beam_search_result(prefixes, beam_size);
209+
return get_beam_search_result(prefixes_copy, beam_size);
210+
}
211+
212+
std::vector<std::pair<double, Output>> ctc_beam_search_decoder(
213+
const std::vector<std::vector<double>> &probs_seq,
214+
const std::vector<std::string> &vocabulary,
215+
size_t beam_size,
216+
double cutoff_prob,
217+
size_t cutoff_top_n,
218+
size_t blank_id,
219+
int log_input,
220+
Scorer *ext_scorer)
221+
{
222+
DecoderState state(vocabulary, beam_size, cutoff_prob, cutoff_top_n, blank_id,
223+
log_input, ext_scorer);
224+
state.next(probs_seq);
225+
return state.decode();
189226
}
190227

191228

@@ -199,7 +236,8 @@ ctc_beam_search_decoder_batch(
199236
size_t cutoff_top_n,
200237
size_t blank_id,
201238
int log_input,
202-
Scorer *ext_scorer) {
239+
Scorer *ext_scorer)
240+
{
203241
VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!");
204242
// thread pool
205243
ThreadPool pool(num_processes);

ctcdecode/src/ctc_beam_search_decoder.h

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,4 +64,57 @@ ctc_beam_search_decoder_batch(
6464
int log_input = 0,
6565
Scorer *ext_scorer = nullptr);
6666

67+
class DecoderState
68+
{
69+
int abs_time_step;
70+
int space_id;
71+
size_t beam_size;
72+
double cutoff_prob;
73+
size_t cutoff_top_n;
74+
size_t blank_id;
75+
int log_input;
76+
std::vector<std::string> vocabulary;
77+
Scorer *ext_scorer;
78+
79+
std::vector<PathTrie*> prefixes;
80+
PathTrie root;
81+
82+
public:
83+
/* Initialize CTC beam search decoder for streaming
84+
*
85+
* Parameters:
86+
* vocabulary: A vector of vocabulary.
87+
* beam_size: The width of beam search.
88+
* cutoff_prob: Cutoff probability for pruning.
89+
* cutoff_top_n: Cutoff number for pruning.
90+
* ext_scorer: External scorer to evaluate a prefix, which consists of
91+
* n-gram language model scoring and word insertion term.
92+
* Default null, decoding the input sample without scorer.
93+
*/
94+
DecoderState(const std::vector<std::string> &vocabulary,
95+
size_t beam_size,
96+
double cutoff_prob,
97+
size_t cutoff_top_n,
98+
size_t blank_id,
99+
int log_input,
100+
Scorer *ext_scorer);
101+
~DecoderState() = default;
102+
103+
/* Process logits in decoder stream
104+
*
105+
* Parameters:
106+
* probs: 2-D vector where each element is a vector of probabilities
107+
* over alphabet of one time step.
108+
*/
109+
void next(const std::vector<std::vector<double>> &probs_seq);
110+
111+
/* Get current transcription from the decoder stream state
112+
*
113+
* Return:
114+
* A vector where each element is a pair of score and decoding result,
115+
* in descending order.
116+
*/
117+
std::vector<std::pair<double, Output>> decode() const;
118+
};
119+
67120
#endif // CTC_BEAM_SEARCH_DECODER_H_

ctcdecode/src/decoder_utils.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,19 @@ bool prefix_compare(const PathTrie *x, const PathTrie *y) {
129129
}
130130
}
131131

132+
bool prefix_compare_external_scores(const PathTrie *x, const PathTrie *y,
133+
const std::unordered_map<const PathTrie*, float>& scores) {
134+
if (scores.at(x) == scores.at(y)) {
135+
if (x->character == y->character) {
136+
return false;
137+
} else {
138+
return (x->character < y->character);
139+
}
140+
} else {
141+
return scores.at(x) > scores.at(y);
142+
}
143+
}
144+
132145
void add_word_to_fst(const std::vector<int> &word,
133146
fst::StdVectorFst *dictionary) {
134147
if (dictionary->NumStates() == 0) {

ctcdecode/src/decoder_utils.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
#ifndef DECODER_UTILS_H_
22
#define DECODER_UTILS_H_
33

4+
#include <unordered_map>
45
#include <utility>
6+
#include <vector>
7+
58
#include "fst/log.h"
69
#include "path_trie.h"
710
#include "output.h"
@@ -62,9 +65,12 @@ std::vector<std::pair<double, Output>> get_beam_search_result(
6265
const std::vector<PathTrie *> &prefixes,
6366
size_t beam_size);
6467

65-
// Functor for prefix comparsion
68+
// Functor for prefix comparison
6669
bool prefix_compare(const PathTrie *x, const PathTrie *y);
6770

71+
bool prefix_compare_external_scores(const PathTrie *x, const PathTrie *y,
72+
const std::unordered_map<const PathTrie*, float>& scores);
73+
6874
/* Get length of utf8 encoding string
6975
* See: http://stackoverflow.com/a/4063229
7076
*/

0 commit comments

Comments
 (0)