
Commit 226a698
Hang Lyu committed
bucketqueue without GetCutoff
1 parent a4a2ddc

3 files changed: 120 additions, 148 deletions

src/decoder/lattice-faster-decoder-combine-bucketqueue.cc
20 additions, 14 deletions

@@ -42,22 +42,28 @@ BucketQueue<Token>::BucketQueue(BaseFloat best_cost_estimate,
 template<typename Token>
 void BucketQueue<Token>::Push(Token *tok) {
   int32 bucket_index = std::floor(tok->tot_cost * cost_scale_);
-  int32 vec_index = bucket_index - bucket_storage_begin_;
+  size_t vec_index = static_cast<size_t>(bucket_index - bucket_storage_begin_);
 
-  if (vec_index < 0) {
+  if (vec_index >= buckets_.size()) {
     KALDI_WARN << "Have to reallocate the BucketQueue. Maybe need to reserve"
-               << " more elements in constructor. Push front.";
-    int32 increase_size = - vec_index;
-    std::vector<std::vector<Token*> > tmp(buckets_);
-    buckets_.resize(tmp.size() + increase_size);
-    std::copy(tmp.begin(), tmp.end(), buckets_.begin() + increase_size);
-    // Update start point
-    bucket_storage_begin_ = bucket_index;
-    vec_index = 0;
-  } else if (vec_index > buckets_.size() - 1) {
-    KALDI_WARN << "Have to reallocate the BucketQueue. Maybe need to reserve"
-               << " more elements in constructor. Push back.";
-    buckets_.resize(vec_index + 1);
+               << " more elements in constructor.";
+    int32 offset = static_cast<int32>(vec_index);
+    // Add a margin here (e.g. 10).
+    int32 increase_size = offset >= 0 ? offset + 1 - buckets_.size() + 10 :
+                                        - offset + 10;
+    buckets_.resize(buckets_.size() + increase_size);
+
+    // Push front
+    if (offset < 0) {
+      std::vector<std::vector<Token*> > tmp(buckets_);
+      buckets_.clear();
+      for (int32 i = 10 - offset; i < buckets_.size(); i++) {
+        buckets_[i].swap(tmp[i + offset - 10]);
+      }
+      // Update start point
+      bucket_storage_begin_ = bucket_index - 10;
+      vec_index = 10;
+    }
   }
 
   tok->in_queue = true;
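
For orientation, the BucketQueue being patched is a coarse priority queue: a token's cost is multiplied by cost_scale_ and floored to an integer bucket index, Push appends the token to the matching bucket (growing the bucket vector with a margin of 10 when the index falls outside it, as above), and Pop scans from the cheapest possibly-nonempty bucket. Below is a minimal self-contained sketch of that idea; the class name SimpleBucketQueue, the fixed initial capacity, and the first_nonempty_ cursor are illustrative assumptions, not the committed Kaldi code.

#include <cmath>
#include <cstdint>
#include <vector>

struct Token { float tot_cost; bool in_queue; };

class SimpleBucketQueue {
 public:
  // best_cost_estimate seeds the bucket index of buckets_[0]; cost_scale
  // controls granularity: the bucket width in cost units is 1 / cost_scale.
  SimpleBucketQueue(float best_cost_estimate, float cost_scale)
      : cost_scale_(cost_scale),
        begin_(static_cast<int32_t>(
            std::floor(best_cost_estimate * cost_scale))),
        first_nonempty_(0),
        buckets_(100) {}

  void Push(Token *tok) {
    int32_t bucket =
        static_cast<int32_t>(std::floor(tok->tot_cost * cost_scale_));
    int32_t idx = bucket - begin_;
    if (idx < 0 || idx >= static_cast<int32_t>(buckets_.size())) {
      Grow(idx);              // reallocate with a margin, as the patch does
      idx = bucket - begin_;  // recompute against the (possibly new) origin
    }
    buckets_[idx].push_back(tok);
    tok->in_queue = true;
    if (idx < first_nonempty_) first_nonempty_ = idx;
  }

  // Returns a token from the cheapest nonempty bucket, or NULL when empty.
  // Order inside one bucket is unspecified: this is only an approximate
  // priority queue, accurate to the bucket width.
  Token *Pop() {
    for (; first_nonempty_ < static_cast<int32_t>(buckets_.size());
         first_nonempty_++) {
      std::vector<Token*> &b = buckets_[first_nonempty_];
      while (!b.empty()) {
        Token *tok = b.back();
        b.pop_back();
        if (tok->in_queue) {  // skip entries invalidated by a cheaper re-push
          tok->in_queue = false;
          return tok;
        }
      }
    }
    return NULL;
  }

 private:
  void Grow(int32_t idx) {
    const int32_t kMargin = 10;
    if (idx >= 0) {  // grow at the back
      buckets_.resize(idx + 1 + kMargin);
    } else {         // grow at the front and shift the origin
      buckets_.insert(buckets_.begin(), -idx + kMargin,
                      std::vector<Token*>());
      begin_ += idx - kMargin;
      first_nonempty_ += -idx + kMargin;
    }
  }

  float cost_scale_;
  int32_t begin_;           // bucket index that buckets_[0] corresponds to
  int32_t first_nonempty_;  // lower bound on the cheapest nonempty bucket
  std::vector<std::vector<Token*> > buckets_;
};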

src/decoder/lattice-faster-decoder-combine.cc
92 additions, 129 deletions

@@ -42,22 +42,28 @@ BucketQueue<Token>::BucketQueue(BaseFloat best_cost_estimate,
 template<typename Token>
 void BucketQueue<Token>::Push(Token *tok) {
   int32 bucket_index = std::floor(tok->tot_cost * cost_scale_);
-  int32 vec_index = bucket_index - bucket_storage_begin_;
+  size_t vec_index = static_cast<size_t>(bucket_index - bucket_storage_begin_);
 
-  if (vec_index < 0) {
+  if (vec_index >= buckets_.size()) {
     KALDI_WARN << "Have to reallocate the BucketQueue. Maybe need to reserve"
-               << " more elements in constructor. Push front.";
-    int32 increase_size = - vec_index;
-    std::vector<std::vector<Token*> > tmp(buckets_);
-    buckets_.resize(tmp.size() + increase_size);
-    std::copy(tmp.begin(), tmp.end(), buckets_.begin() + increase_size);
-    // Update start point
-    bucket_storage_begin_ = bucket_index;
-    vec_index = 0;
-  } else if (vec_index > buckets_.size() - 1) {
-    KALDI_WARN << "Have to reallocate the BucketQueue. Maybe need to reserve"
-               << " more elements in constructor. Push back.";
-    buckets_.resize(vec_index + 1);
+               << " more elements in constructor.";
+    int32 offset = static_cast<int32>(vec_index);
+    // Add a margin here (e.g. 10).
+    int32 increase_size = offset >= 0 ? offset + 1 - buckets_.size() + 10 :
+                                        - offset + 10;
+    buckets_.resize(buckets_.size() + increase_size);
+
+    // Push front
+    if (offset < 0) {
+      std::vector<std::vector<Token*> > tmp(buckets_);
+      buckets_.clear();
+      for (int32 i = 10 - offset; i < buckets_.size(); i++) {
+        buckets_[i].swap(tmp[i + offset - 10]);
+      }
+      // Update start point
+      bucket_storage_begin_ = bucket_index - 10;
+      vec_index = 10;
+    }
   }
 
   tok->in_queue = true;

@@ -143,6 +149,7 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::InitDecoding() {
   cur_toks_[start_state] = start_tok;  // initialize current tokens map
   num_toks_++;
   best_token_in_next_frame_ = start_tok;
+  adaptive_beam_ = config_.beam;
 }
 
 // Returns true if any kind of traceback is available (not necessarily from
@@ -753,67 +760,6 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::FinalizeDecoding() {
             << " to " << num_toks_;
 }
 
-/// Gets the weight cutoff.
-template <typename FST, typename Token>
-BaseFloat LatticeFasterDecoderCombineTpl<FST, Token>::GetCutoff(
-    const TokenList &token_list, const Token* best_token,
-    BaseFloat *adaptive_beam, BucketQueue *queue) {
-  BaseFloat best_weight = best_token->tot_cost;
-  // positive == high cost == bad.
-  // best_weight is the minimum value.
-  if (config_.max_active == std::numeric_limits<int32>::max() &&
-      config_.min_active == 0) {
-    for (Token* tok = token_list.toks; tok != NULL; tok = tok->next) {
-      queue->Push(tok);
-    }
-    if (adaptive_beam != NULL) *adaptive_beam = config_.beam;
-    return best_weight + config_.beam;
-  } else {
-    tmp_array_.clear();
-    for (Token* tok = token_list.toks; tok != NULL; tok = tok->next) {
-      BaseFloat w = static_cast<BaseFloat>(tok->tot_cost);
-      tmp_array_.push_back(w);
-      queue->Push(tok);
-    }
-
-    BaseFloat beam_cutoff = best_weight + config_.beam,
-        min_active_cutoff = std::numeric_limits<BaseFloat>::infinity(),
-        max_active_cutoff = std::numeric_limits<BaseFloat>::infinity();
-
-    KALDI_VLOG(6) << "Number of emitting tokens on frame "
-                  << NumFramesDecoded() - 1 << " is " << tmp_array_.size();
-
-    if (tmp_array_.size() > static_cast<size_t>(config_.max_active)) {
-      std::nth_element(tmp_array_.begin(),
-                       tmp_array_.begin() + config_.max_active,
-                       tmp_array_.end());
-      max_active_cutoff = tmp_array_[config_.max_active];
-    }
-    if (max_active_cutoff < beam_cutoff) {  // max_active is tighter than beam.
-      if (adaptive_beam)
-        *adaptive_beam = max_active_cutoff - best_weight + config_.beam_delta;
-      return max_active_cutoff;
-    }
-    if (tmp_array_.size() > static_cast<size_t>(config_.min_active)) {
-      if (config_.min_active == 0) min_active_cutoff = best_weight;
-      else {
-        std::nth_element(tmp_array_.begin(),
-                         tmp_array_.begin() + config_.min_active,
-                         tmp_array_.size() > static_cast<size_t>(config_.max_active) ?
-                         tmp_array_.begin() + config_.max_active : tmp_array_.end());
-        min_active_cutoff = tmp_array_[config_.min_active];
-      }
-    }
-    if (min_active_cutoff > beam_cutoff) {  // min_active is looser than beam.
-      if (adaptive_beam)
-        *adaptive_beam = min_active_cutoff - best_weight + config_.beam_delta;
-      return min_active_cutoff;
-    } else {
-      *adaptive_beam = config_.beam;
-      return beam_cutoff;
-    }
-  }
-}
 
 template <typename FST, typename Token>
 void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessForFrame(
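
The deleted GetCutoff combined three bounds: the beam, a max-active cutoff found via std::nth_element, and a min-active cutoff. The new code gets the max-active effect instead by popping at most max_active tokens from the bucket queue (see the later hunks), but the old selection trick is worth seeing in isolation. A minimal restatement, not the deleted code verbatim; the function name MaxActiveCutoff is mine:

#include <algorithm>
#include <limits>
#include <vector>

// Returns the pruning cutoff implied by a max-active constraint: the cost of
// the (max_active+1)-th best token, or +inf if there are fewer tokens.
float MaxActiveCutoff(std::vector<float> costs, size_t max_active) {
  if (costs.size() <= max_active)
    return std::numeric_limits<float>::infinity();
  // Partial selection in O(n): after this call the element at position
  // max_active is the one that would be there in a full sort.
  std::nth_element(costs.begin(), costs.begin() + max_active, costs.end());
  return costs[max_active];
}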
@@ -834,51 +780,27 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessForFrame(
   }
 
   KALDI_ASSERT(best_token_in_next_frame_);
-  BucketQueue cur_queue(best_token_in_next_frame_->tot_cost);
-  BaseFloat adaptive_beam;
-  // "cur_cutoff" is used to constrain the epsilon emission in the current
-  // frame. It will not be updated.
-  BaseFloat cur_cutoff = GetCutoff(active_toks_[frame],
-                                   best_token_in_next_frame_,
-                                   &adaptive_beam, &cur_queue);
-  KALDI_VLOG(6) << "Adaptive beam on frame " << NumFramesDecoded() << " is "
-                << adaptive_beam;
-
-  // pruning "online" before having seen all tokens
+  BucketQueue cur_queue(best_token_in_next_frame_->tot_cost, config_.cost_scale);
+  // Add tokens to the queue.
+  for (Token* tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) {
+    cur_queue.Push(tok);
+  }
 
+  // Declare a local variable so the compiler can put it in a register, since
+  // C++ assumes other threads could be modifying class members.
+  BaseFloat adaptive_beam = adaptive_beam_;
+  // "cur_cutoff" will be kept updated to the best-seen-so-far token on this
+  // frame + adaptive_beam.
+  BaseFloat cur_cutoff = std::numeric_limits<BaseFloat>::infinity();
   // "next_cutoff" is used to decide whether a new token on the next frame
   // should be handled or not. It will be updated along with further processing.
+  // It will be kept updated to the best-seen-so-far token "on the next frame"
+  // + adaptive_beam.
   BaseFloat next_cutoff = std::numeric_limits<BaseFloat>::infinity();
   // "cost_offset" contains the acoustic log-likelihoods on the current frame in
   // order to keep everything in a nice dynamic range. Reduces roundoff errors.
-  BaseFloat cost_offset = 0.0;
-
-  // First process the best token to get a hopefully
-  // reasonably tight bound on the next cutoff. The only
-  // products of the next block are "next_cutoff" and "cost_offset".
-  // Notice: as a difference from the traditional version, in this combine
-  // version the "best_tok" is chosen from the emitting tokens. Normally, the
-  // best token of one frame comes from an epsilon non-emitting arc, so this
-  // best token is a looser boundary. We use it to estimate a bound on the
-  // next cutoff and we will update "next_cutoff" once we have better tokens.
-  // The "next_cutoff" will be updated in further processing.
-  Token *best_tok = best_token_in_next_frame_;
-  StateId best_tok_state_id = best_tok->state_id;
-  if (best_tok) {
-    cost_offset = - best_tok->tot_cost;
-    for (fst::ArcIterator<FST> aiter(*fst_, best_tok_state_id);
-         !aiter.Done();
-         aiter.Next()) {
-      const Arc &arc = aiter.Value();
-      if (arc.ilabel != 0) {  // propagate..
-        // ac_cost + graph_cost
-        BaseFloat new_weight = arc.weight.Value() + cost_offset -
-            decodable->LogLikelihood(frame, arc.ilabel) + best_tok->tot_cost;
-        if (new_weight + adaptive_beam < next_cutoff)
-          next_cutoff = new_weight + adaptive_beam;
-      }
-    }
-  }
+  BaseFloat cost_offset = - best_token_in_next_frame_->tot_cost;
+
   best_token_in_next_frame_ = NULL;
   // Store the offset on the acoustic likelihoods that we're applying.
   // Could just do cost_offsets_.push_back(cost_offset), but we
@@ -888,11 +810,17 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessForFrame(
 
   // Iterate over "cur_queue_" to process non-emitting and emitting arcs in fst.
   Token *tok = NULL;
-  while ((tok = cur_queue.Pop()) != NULL) {
+  int32 num_toks_processed = 0;
+  int32 max_active = config_.max_active;
+  for (; num_toks_processed < max_active && (tok = cur_queue.Pop()) != NULL;
+       num_toks_processed++) {
     BaseFloat cur_cost = tok->tot_cost;
     StateId state = tok->state_id;
-    if (cur_cost > cur_cutoff)  // Don't bother processing successors.
-      continue;
+    if (cur_cost > cur_cutoff) {  // Don't bother processing successors.
+      break;  // This is a priority queue. The following tokens will be worse.
+    } else if (cur_cost + adaptive_beam < cur_cutoff) {
+      cur_cutoff = cur_cost + adaptive_beam;  // a tighter boundary
+    }
     // If "tok" has any existing forward links, delete them,
     // because we're about to regenerate them. This is a kind
    // of non-optimality (remember, this is the simple decoder),
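
The switch from continue to break above relies on the queue popping tokens in roughly best-first order, and "roughly" is the operative word: within one bucket the order is unspecified, so the break can over-prune, but only by up to one bucket width (1 / cost_scale). A toy illustration using the SimpleBucketQueue sketch from the first file (the values are made up):

// Illustrative only; builds on the SimpleBucketQueue sketch above.
void BucketOrderDemo() {
  Token a = {3.2f, false}, b = {3.9f, false}, c = {4.1f, false};
  SimpleBucketQueue q(3.2f, /* cost_scale = */ 1.0f);
  q.Push(&a); q.Push(&b); q.Push(&c);
  // a and b share bucket floor(3.x) = 3; c sits alone in bucket 4. Pops
  // return a and b (in unspecified order) before c, so once a popped cost
  // exceeds cur_cutoff, any token still queued can undercut that cost by at
  // most one bucket width; breaking out of the loop is safe up to that
  // tolerance.
  Token *first = q.Pop();  // a or b, never c
  (void)first;
}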
@@ -945,8 +873,32 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessForFrame(
       }
     }  // for all arcs
   }  // end of while loop
-  //KALDI_VLOG(6) << "Number of tokens active on frame " << NumFramesDecoded() - 1
-  //              << " is " << prev_toks_.size();
+
+  {  // This block updates adaptive_beam_
+    BaseFloat beam_used_this_frame = adaptive_beam;
+    Token *tok = cur_queue.Pop();
+    if (tok != NULL) {
+      // The queue would only be nonempty if we hit the max-active constraint.
+      BaseFloat best_cost_this_frame = cur_cutoff - adaptive_beam;
+      beam_used_this_frame = tok->tot_cost - best_cost_this_frame;
+    }
+    if (num_toks_processed <= config_.min_active) {
+      // num-toks active is dangerously low, increase the beam even if it
+      // already exceeds the user-specified beam.
+      adaptive_beam_ = std::max<BaseFloat>(
+          config_.beam, beam_used_this_frame + 2.0 * config_.beam_delta);
+    } else {
+      // have adaptive_beam_ approach beam_ in intervals of config_.beam_delta
+      BaseFloat diff_from_beam = beam_used_this_frame - config_.beam;
+      if (std::abs(diff_from_beam) < config_.beam_delta) {
+        adaptive_beam_ = config_.beam;
+      } else {
+        // make it close to beam_
+        adaptive_beam_ = beam_used_this_frame -
+            config_.beam_delta * (diff_from_beam > 0 ? 1 : -1);
+      }
+    }
+  }
 }
 
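Since this adaptive_beam_ update replaces the per-frame GetCutoff computation, it may help to see the rule by itself. A free-function restatement under assumed names: UpdateAdaptiveBeam is not in the patch, and beam, beam_delta, and min_active mirror config_.beam, config_.beam_delta, and config_.min_active.

#include <algorithm>
#include <cmath>

// Sketch of the per-frame adaptive-beam rule introduced above (illustrative).
float UpdateAdaptiveBeam(float beam_used_this_frame, int num_toks_processed,
                         float beam, float beam_delta, int min_active) {
  if (num_toks_processed <= min_active) {
    // Dangerously few tokens survived: widen, even past the configured beam.
    return std::max(beam, beam_used_this_frame + 2.0f * beam_delta);
  }
  float diff = beam_used_this_frame - beam;
  if (std::fabs(diff) < beam_delta)
    return beam;  // within one step of the configured beam: snap to it
  // Otherwise move one beam_delta back toward the configured beam.
  return beam_used_this_frame - beam_delta * (diff > 0 ? 1.0f : -1.0f);
}

So the beam widens quickly whenever the min-active constraint is threatened, and otherwise relaxes back toward config_.beam by one beam_delta per frame.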
@@ -969,20 +921,31 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessNonemitting(
     tmp_toks = &cur_toks_;
   }
 
-  BucketQueue cur_queue(best_token_in_next_frame_->tot_cost);
-  // "cur_cutoff" is used to constrain the epsilon emission in the current
-  // frame. It will not be updated.
-  BaseFloat adaptive_beam;
-  BaseFloat cur_cutoff = GetCutoff(active_toks_[frame],
-                                   best_token_in_next_frame_,
-                                   &adaptive_beam, &cur_queue);
+  BucketQueue cur_queue(best_token_in_next_frame_->tot_cost, config_.cost_scale);
+  for (Token* tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) {
+    cur_queue.Push(tok);
+  }
+
+  // Declare a local variable so the compiler can put it in a register, since
+  // C++ assumes other threads could be modifying class members.
+  BaseFloat adaptive_beam = adaptive_beam_;
+  // "cur_cutoff" will be kept updated to the best-seen-so-far token on this
+  // frame + adaptive_beam.
+  BaseFloat cur_cutoff = std::numeric_limits<BaseFloat>::infinity();
 
   Token *tok = NULL;
-  while ((tok = cur_queue.Pop()) != NULL) {
+  int32 num_toks_processed = 0;
+  int32 max_active = config_.max_active;
+
+  for (; num_toks_processed < max_active && (tok = cur_queue.Pop()) != NULL;
+       num_toks_processed++) {
     BaseFloat cur_cost = tok->tot_cost;
     StateId state = tok->state_id;
-    if (cur_cost > cur_cutoff)  // Don't bother processing successors.
-      continue;
+    if (cur_cost > cur_cutoff) {  // Don't bother processing successors.
+      break;  // This is a priority queue. The following tokens will be worse.
+    } else if (cur_cost + adaptive_beam < cur_cutoff) {
+      cur_cutoff = cur_cost + adaptive_beam;  // a tighter boundary
+    }
     // If "tok" has any existing forward links, delete them,
     // because we're about to regenerate them. This is a kind
     // of non-optimality (remember, this is the simple decoder),

src/decoder/lattice-faster-decoder-combine.h
8 additions, 5 deletions

@@ -46,6 +46,7 @@ struct LatticeFasterDecoderCombineConfig {
   // command-line program.
   BaseFloat beam_delta;  // has nothing to do with beam_ratio
   BaseFloat hash_ratio;
+  BaseFloat cost_scale;
   BaseFloat prune_scale;  // Note: we don't make this configurable on the command line,
                           // it's not a very important parameter. It affects the
                           // algorithm that prunes the tokens as we go.

@@ -62,6 +63,7 @@ struct LatticeFasterDecoderCombineConfig {
       determinize_lattice(true),
       beam_delta(0.5),
       hash_ratio(2.0),
+      cost_scale(1.0),
       prune_scale(0.1) { }
   void Register(OptionsItf *opts) {
     det_opts.Register(opts);
@@ -81,6 +83,10 @@ struct LatticeFasterDecoderCombineConfig {
                    "max-active constraint is applied. Larger is more accurate.");
     opts->Register("hash-ratio", &hash_ratio, "Setting used in decoder to "
                    "control hash behavior");
+    opts->Register("cost-scale", &cost_scale, "A scale that we multiply the "
+                   "token costs by before integerizing; a larger value means "
+                   "more buckets and more precise ordering.");
+
   }
   void Check() const {
     KALDI_ASSERT(beam > 0.0 && max_active > 1 && lattice_beam > 0.0
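
To make the new option concrete: the queue maps a cost to bucket floor(cost * cost_scale), so the bucket width in cost units is 1 / cost_scale. A hypothetical helper just to show the arithmetic (BucketIndex is not part of the patch):

#include <cmath>

int BucketIndex(float cost, float cost_scale) {
  return static_cast<int>(std::floor(cost * cost_scale));
}
// cost_scale = 1.0 : costs 3.2 and 3.9 collide in bucket 3 (coarse ordering).
// cost_scale = 10.0: they land in buckets 32 and 39; ordering is finer, at
// the price of more buckets to store and scan.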
@@ -570,16 +576,11 @@ class LatticeFasterDecoderCombineTpl {
   /// on a complete token list on one frame. But, in this version, it is used
   /// on a token list which only contains the emitting part. So the max_active
   /// and min_active values might be narrowed.
-  BaseFloat GetCutoff(const TokenList &token_list, const Token* best_token,
-                      BaseFloat *adaptive_beam,
-                      BucketQueue *queue);
-
   std::vector<TokenList> active_toks_;  // Lists of tokens, indexed by
   // frame (members of TokenList are toks, must_prune_forward_links,
   // must_prune_tokens).
   std::queue<StateId> cur_queue_;  // temp variable used in ProcessForFrame
                                    // and ProcessNonemitting
-  std::vector<BaseFloat> tmp_array_;  // used in GetCutoff.
   // Stores the best token in next frame. The tot_cost of it will be used to
   // initialize the BucketQueue.
   Token* best_token_in_next_frame_;

@@ -614,6 +615,8 @@ class LatticeFasterDecoderCombineTpl {
   BaseFloat final_relative_cost_;
   BaseFloat final_best_cost_;
 
+  BaseFloat adaptive_beam_;  // will be set to beam_ when we start
+
   // This function takes a singly linked list of tokens for a single frame, and
   // outputs a list of them in topological order (it will crash if no such order
   // can be found, which will typically be due to decoding graphs with epsilon
