1414
// Shorthand for the OpenFst sorted matcher over a StdVectorFst; instantiated
// below to build the matcher that is attached to the root prefix.
1515using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
1616
17- std::vector<std::pair<double , Output>> ctc_beam_search_decoder (
18- const std::vector<std::vector<double >> &probs_seq,
19- const std::vector<std::string> &vocabulary,
20- size_t beam_size,
21- double cutoff_prob,
22- size_t cutoff_top_n,
23- size_t blank_id,
24- int log_input,
25- Scorer *ext_scorer) {
26- // dimension check
27- size_t num_time_steps = probs_seq.size ();
28- for (size_t i = 0 ; i < num_time_steps; ++i) {
29- VALID_CHECK_EQ (probs_seq[i].size (),
30- vocabulary.size (),
31- " The shape of probs_seq does not match with "
32- " the shape of the vocabulary" );
33- }
34-
35- // assign blank id
36- // size_t blank_id = vocabulary.size();
37-
// Set up the decoder state before any probabilities are fed in.
//
// Stores the decoding parameters, starts abs_time_step at 0, resolves the
// index of the " " entry in the vocabulary (space_id, or -2 when there is
// none), zeroes the root prefix's score / log_prob_b_prev and registers the
// root as the first entry of `prefixes`. When a scorer is supplied and it is
// not character based, a sorted matcher is attached to the root.
//
// NOTE(review): `prefixes` holds the address of `root`; root looks like a
// member of DecoderState here, so copying or moving a DecoderState would
// leave that pointer referring to the source object - confirm against the
// class declaration.
17+ DecoderState::DecoderState (const std::vector<std::string> &vocabulary,
18+ size_t beam_size,
19+ double cutoff_prob,
20+ size_t cutoff_top_n,
21+ size_t blank_id,
22+ int log_input,
23+ Scorer *ext_scorer)
24+ : abs_time_step(0 )
25+ , beam_size(beam_size)
26+ , cutoff_prob(cutoff_prob)
27+ , cutoff_top_n(cutoff_top_n)
28+ , blank_id(blank_id)
29+ , log_input(log_input)
30+ , vocabulary(vocabulary)
31+ , ext_scorer(ext_scorer)
32+ {
3833 // assign space id
3934 auto it = std::find (vocabulary.begin (), vocabulary.end (), " " );
40- int space_id = it - vocabulary.begin ();
4135 // if no space in vocabulary
42- if (( size_t )space_id >= vocabulary.size ()) {
36+ if (it == vocabulary.end ()) {
// -2 is the marker stored when the vocabulary contains no " " entry
4337 space_id = -2 ;
38+ } else {
39+ space_id = std::distance (vocabulary.begin (), it);
4440 }
4541
4642 // init prefixes' root
47- PathTrie root;
4843 root.score = root.log_prob_b_prev = 0.0 ;
49- std::vector<PathTrie *> prefixes;
5044 prefixes.push_back (&root);
5145
// word-level scorer only: give the root a matcher (dict_ptr is set up in
// lines not visible in this view)
5246 if (ext_scorer != nullptr && !ext_scorer->is_character_based ()) {
@@ -56,9 +50,22 @@ std::vector<std::pair<double, Output>> ctc_beam_search_decoder(
5650 auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
5751 root.set_matcher (matcher);
5852 }
53+ }
54+
// Feed a batch of time steps into the decoder.
//
// probs_seq holds one row per time step; every row must have exactly
// vocabulary.size() entries (enforced with VALID_CHECK_EQ). The prefix search
// then runs over the rows in order, advancing abs_time_step once per row.
// NOTE(review): abs_time_step is never reset here, so it presumably keeps
// counting across successive calls to next() - confirm with the class
// declaration / callers.
55+ void
56+ DecoderState::next (const std::vector<std::vector<double >> &probs_seq)
57+ {
58+ // dimension check
59+ size_t num_time_steps = probs_seq.size ();
60+ for (size_t i = 0 ; i < num_time_steps; ++i) {
61+ VALID_CHECK_EQ (probs_seq[i].size (),
62+ vocabulary.size (),
63+ " The shape of probs_seq does not match with "
64+ " the shape of the vocabulary" );
65+ }
5966
6067 // prefix search over time
61- for (size_t time_step = 0 ; time_step < num_time_steps; ++time_step) {
// time_step indexes into this call's probs_seq; abs_time_step is the member
// counter that is handed to get_path_trie() further down
68+ for (size_t time_step = 0 ; time_step < num_time_steps; ++time_step, ++abs_time_step ) {
6269 auto &prob = probs_seq[time_step];
6370
6471 float min_cutoff = -NUM_FLT_INF;
@@ -97,7 +104,7 @@ std::vector<std::pair<double, Output>> ctc_beam_search_decoder(
97104 prefix->log_prob_nb_cur , log_prob_c + prefix->log_prob_nb_prev );
98105 }
99106 // get new prefix
100- auto prefix_new = prefix->get_path_trie (c, time_step , log_prob_c);
107+ auto prefix_new = prefix->get_path_trie (c, abs_time_step , log_prob_c);
101108
102109 if (prefix_new != nullptr ) {
103110 float log_p = -NUM_FLT_INF;
@@ -147,45 +154,75 @@ std::vector<std::pair<double, Output>> ctc_beam_search_decoder(
147154 for (size_t i = beam_size; i < prefixes.size (); ++i) {
148155 prefixes[i]->remove ();
149156 }
157+
// drop the entries past beam_size, on which remove() was just called, so the
// member vector no longer lists them
158+ prefixes.resize (beam_size);
150159 }
151160 } // end of loop over time
161+ }
162+
163+ std::vector<std::pair<double , Output>>
164+ DecoderState::decode () const
165+ {
166+ std::vector<PathTrie*> prefixes_copy = prefixes;
167+ std::unordered_map<const PathTrie*, float > scores;
168+ for (PathTrie* prefix : prefixes_copy) {
169+ scores[prefix] = prefix->score ;
170+ }
152171
153172 // score the last word of each prefix that doesn't end with space
154173 if (ext_scorer != nullptr && !ext_scorer->is_character_based ()) {
155- for (size_t i = 0 ; i < beam_size && i < prefixes .size (); ++i) {
156- auto prefix = prefixes [i];
174+ for (size_t i = 0 ; i < beam_size && i < prefixes_copy .size (); ++i) {
175+ auto prefix = prefixes_copy [i];
157176 if (!prefix->is_empty () && prefix->character != space_id) {
158177 float score = 0.0 ;
159178 std::vector<std::string> ngram = ext_scorer->make_ngram (prefix);
160179 score = ext_scorer->get_log_cond_prob (ngram) * ext_scorer->alpha ;
161180 score += ext_scorer->beta ;
162- prefix-> score += score;
181+ scores[ prefix] += score;
163182 }
164183 }
165184 }
166185
167- size_t num_prefixes = std::min (prefixes.size (), beam_size);
168- std::sort (prefixes.begin (), prefixes.begin () + num_prefixes, prefix_compare);
186+ using namespace std ::placeholders;
187+ size_t num_prefixes = std::min (prefixes_copy.size (), beam_size);
188+ std::sort (prefixes_copy.begin (), prefixes_copy.begin () + num_prefixes,
189+ std::bind (prefix_compare_external_scores, _1, _2, scores));
169190
170191 // compute aproximate ctc score as the return score, without affecting the
171192 // return order of decoding result. To delete when decoder gets stable.
172- for (size_t i = 0 ; i < beam_size && i < prefixes .size (); ++i) {
173- double approx_ctc = prefixes[i]-> score ;
193+ for (size_t i = 0 ; i < beam_size && i < prefixes_copy .size (); ++i) {
194+ double approx_ctc = scores[prefixes_copy[i]] ;
174195 if (ext_scorer != nullptr ) {
175196 std::vector<int > output;
176197 std::vector<int > timesteps;
177- prefixes [i]->get_path_vec (output, timesteps);
198+ prefixes_copy [i]->get_path_vec (output, timesteps);
178199 auto prefix_length = output.size ();
179200 auto words = ext_scorer->split_labels (output);
180201 // remove word insert
181202 approx_ctc = approx_ctc - prefix_length * ext_scorer->beta ;
182203 // remove language model weight:
183204 approx_ctc -= (ext_scorer->get_sent_log_prob (words)) * ext_scorer->alpha ;
184205 }
185- prefixes [i]->approx_ctc = approx_ctc;
206+ prefixes_copy [i]->approx_ctc = approx_ctc;
186207 }
187208
188- return get_beam_search_result (prefixes, beam_size);
209+ return get_beam_search_result (prefixes_copy, beam_size);
210+ }
211+
212+ std::vector<std::pair<double , Output>> ctc_beam_search_decoder (
213+ const std::vector<std::vector<double >> &probs_seq,
214+ const std::vector<std::string> &vocabulary,
215+ size_t beam_size,
216+ double cutoff_prob,
217+ size_t cutoff_top_n,
218+ size_t blank_id,
219+ int log_input,
220+ Scorer *ext_scorer)
221+ {
222+ DecoderState state (vocabulary, beam_size, cutoff_prob, cutoff_top_n, blank_id,
223+ log_input, ext_scorer);
224+ state.next (probs_seq);
225+ return state.decode ();
189226}
190227
191228
@@ -199,7 +236,8 @@ ctc_beam_search_decoder_batch(
199236 size_t cutoff_top_n,
200237 size_t blank_id,
201238 int log_input,
202- Scorer *ext_scorer) {
239+ Scorer *ext_scorer)
240+ {
203241 VALID_CHECK_GT (num_processes, 0 , " num_processes must be nonnegative!" );
204242 // thread pool
205243 ThreadPool pool (num_processes);
0 commit comments