Commit b6abf43

Author: Hang Lyu (committed)
first_nonempty_bucket_index_ and first_nonempty_bucket_
Parent: b636119

File tree: 4 files changed (+186, -186 lines)

src/decoder/lattice-faster-decoder-combine-bucketqueue.cc

Lines changed: 57 additions & 54 deletions
@@ -27,82 +27,87 @@
 namespace kaldi {
 
 template<typename Token>
-BucketQueue<Token>::BucketQueue(BaseFloat best_cost_estimate,
-                                BaseFloat cost_scale) :
+BucketQueue<Token>::BucketQueue(BaseFloat cost_scale) :
     cost_scale_(cost_scale) {
   // NOTE: we reserve plenty of elements to avoid expensive reallocations
   // later on. Normally, the size is a little bigger than (adaptive_beam +
-  // 5) * cost_scale.
+  // 15) * cost_scale.
   int32 bucket_size = 100;
   buckets_.resize(bucket_size);
-  bucket_storage_begin_ = std::floor((best_cost_estimate - 15) * cost_scale_);
-  first_occupied_vec_index_ = bucket_size;
+  bucket_offset_ = 15 * cost_scale_;
+  first_nonempty_bucket_index_ = bucket_size - 1;
+  first_nonempty_bucket_ = &buckets_[first_nonempty_bucket_index_];
 }
 
 template<typename Token>
 void BucketQueue<Token>::Push(Token *tok) {
-  int32 bucket_index = std::floor(tok->tot_cost * cost_scale_);
-  size_t vec_index = static_cast<size_t>(bucket_index - bucket_storage_begin_);
-  if (vec_index >= buckets_.size()) {
+  size_t bucket_index = std::floor(tok->tot_cost * cost_scale_) +
+                        bucket_offset_;
+  if (bucket_index >= buckets_.size()) {
     int32 margin = 10; // a margin which is used to reduce re-allocate
                        // space frequently
-    // A cast from unsigned to signed type does not generate a machine-code
-    // instruction
-    if (static_cast<int32>(vec_index) > 0) {
-      KALDI_WARN << "Have to reallocate the BucketQueue. Maybe need to reserve"
-                 << " more elements in constructor. Push back.";
-      buckets_.resize(static_cast<int32>(vec_index) + margin);
+    if (static_cast<int32>(bucket_index) > 0) {
+      buckets_.resize(bucket_index + margin);
     } else { // less than 0
-      KALDI_WARN << "Have to reallocate the BucketQueue. Maybe need to reserve"
-                 << " more elements in constructor. Push front.";
-      int32 increase_size = - static_cast<int32>(vec_index) + margin;
-      buckets_.resize(buckets_.size() + increase_size);
-      // translation
-      for (size_t i = buckets_.size() - 1; i >= increase_size; i--) {
-        buckets_[i].swap(buckets_[i - increase_size]);
-      }
-      bucket_storage_begin_ = bucket_storage_begin_ - increase_size;
-      vec_index = static_cast<int32>(vec_index) + increase_size;
-      first_occupied_vec_index_ = vec_index;
+      int32 increase_size = - static_cast<int32>(bucket_index) + margin;
+      buckets_.resize(buckets_.size() + increase_size);
+      // translation
+      for (size_t i = buckets_.size() - 1; i >= increase_size; i--) {
+        buckets_[i].swap(buckets_[i - increase_size]);
+      }
+      bucket_offset_ = bucket_offset_ + increase_size * cost_scale_;
+      bucket_index += increase_size;
+      first_nonempty_bucket_index_ = bucket_index;
+      first_nonempty_bucket_ = &buckets_[first_nonempty_bucket_index_];
     }
   }
   tok->in_queue = true;
-  buckets_[vec_index].push_back(tok);
-  if (vec_index < first_occupied_vec_index_)
-    first_occupied_vec_index_ = vec_index;
+  buckets_[bucket_index].push_back(tok);
+  if (bucket_index < first_nonempty_bucket_index_) {
+    first_nonempty_bucket_index_ = bucket_index;
+    first_nonempty_bucket_ = &buckets_[first_nonempty_bucket_index_];
+  }
 }
 
 template<typename Token>
 Token* BucketQueue<Token>::Pop() {
-  int32 vec_index = first_occupied_vec_index_;
-  while (vec_index < buckets_.size()) {
-    Token* tok = buckets_[vec_index].back();
-    // Remove the best token
-    buckets_[vec_index].pop_back();
-
-    if (buckets_[vec_index].empty()) { // This bucket is empty. Update vec_index
-      int32 next_vec_index = vec_index + 1;
-      for (; next_vec_index < buckets_.size(); next_vec_index++) {
-        if (!buckets_[next_vec_index].empty()) break;
+  while (true) {
+    if (!first_nonempty_bucket_->empty()) {
+      Token *ans = first_nonempty_bucket_->back();
+      first_nonempty_bucket_->pop_back();
+      if (ans->in_queue) {
+        ans->in_queue = false;
+        return ans;
       }
-      vec_index = next_vec_index;
-      first_occupied_vec_index_ = vec_index;
     }
+    if (first_nonempty_bucket_->empty()) {
+      // In case, pop an empty BucketQueue
+      if (first_nonempty_bucket_index_ == buckets_.size() - 1) {
+        return NULL;
+      }
 
-    if (tok->in_queue) { // This is a effective token
-      tok->in_queue = false;
-      return tok;
+      first_nonempty_bucket_index_++;
+      for (; first_nonempty_bucket_index_ < buckets_.size() - 1;
+           first_nonempty_bucket_index_++) {
+        if (!buckets_[first_nonempty_bucket_index_].empty())
+          break;
+      }
+      first_nonempty_bucket_ = &buckets_[first_nonempty_bucket_index_];
+      if (first_nonempty_bucket_index_ == buckets_.size() - 1 &&
+          first_nonempty_bucket_->empty()) {
+        return NULL;
+      }
     }
   }
-  return NULL;
 }
 
 template<typename Token>
 void BucketQueue<Token>::Clear() {
-  for (size_t i = 0; i < buckets_.size(); i++) {
+  for (size_t i = first_nonempty_bucket_index_; i < buckets_.size(); i++) {
     buckets_[i].clear();
   }
-  first_occupied_vec_index_ = buckets_.size();
+  first_nonempty_bucket_index_ = buckets_.size() - 1;
+  first_nonempty_bucket_ = &buckets_[first_nonempty_bucket_index_];
 }
 
 // instantiate this class once for each thing you have to decode.
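
To make the bucketing arithmetic of the new Push()/Pop() above concrete, here is a minimal, self-contained sketch (an illustration only: BucketQueueSketch and Tok are made-up names, and the reallocation branch of Push() is omitted). Costs are scaled, floored, and shifted by an offset to pick a bucket; Pop() scans forward from the first non-empty bucket and skips stale entries via the in_queue flag, which is how re-Pushed tokens get lazily deleted.

#include <cmath>
#include <cstdio>
#include <vector>

struct Tok {
  float tot_cost;
  bool in_queue;
};

class BucketQueueSketch {
 public:
  explicit BucketQueueSketch(float cost_scale = 1.0f)
      : cost_scale_(cost_scale),
        buckets_(100),
        offset_(static_cast<int>(15 * cost_scale)),
        first_nonempty_(static_cast<int>(buckets_.size()) - 1) {}

  void Push(Tok *tok) {
    int index =
        static_cast<int>(std::floor(tok->tot_cost * cost_scale_)) + offset_;
    // The real code grows and shifts buckets_ when index falls outside
    // [0, buckets_.size()); omitted here for brevity.
    tok->in_queue = true;
    buckets_[index].push_back(tok);
    if (index < first_nonempty_) first_nonempty_ = index;
  }

  Tok *Pop() {
    while (true) {
      if (!buckets_[first_nonempty_].empty()) {
        Tok *ans = buckets_[first_nonempty_].back();
        buckets_[first_nonempty_].pop_back();
        if (ans->in_queue) {  // skip stale duplicates left by a re-Push
          ans->in_queue = false;
          return ans;
        }
        continue;  // stale entry: look at this bucket again
      }
      if (first_nonempty_ == static_cast<int>(buckets_.size()) - 1)
        return nullptr;  // queue exhausted
      ++first_nonempty_;  // advance to the next candidate bucket
    }
  }

 private:
  float cost_scale_;
  std::vector<std::vector<Tok*> > buckets_;
  int offset_;
  int first_nonempty_;
};

int main() {
  BucketQueueSketch q;
  Tok a = {2.5f, false}, b = {-3.0f, false};
  q.Push(&a);
  q.Push(&b);
  b.tot_cost = -4.0f;  // cost improved: just Push again, old entry goes stale
  q.Push(&b);
  for (Tok *t = q.Pop(); t != nullptr; t = q.Pop())
    std::printf("popped cost %.1f\n", t->tot_cost);  // -4.0 first, then 2.5
  return 0;
}

Running this prints the improved cost -4.0 before 2.5, mirroring the best-first order the decoder relies on.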
@@ -111,7 +116,7 @@ LatticeFasterDecoderCombineTpl<FST, Token>::LatticeFasterDecoderCombineTpl(
     const FST &fst,
     const LatticeFasterDecoderCombineConfig &config):
     fst_(&fst), delete_fst_(false), config_(config), num_toks_(0),
-    cur_queue_(0, config_.cost_scale) {
+    cur_queue_(config_.cost_scale) {
   config.Check();
   prev_toks_.reserve(1000);
   cur_toks_.reserve(1000);
@@ -122,7 +127,7 @@ template <typename FST, typename Token>
 LatticeFasterDecoderCombineTpl<FST, Token>::LatticeFasterDecoderCombineTpl(
     const LatticeFasterDecoderCombineConfig &config, FST *fst):
     fst_(fst), delete_fst_(true), config_(config), num_toks_(0),
-    cur_queue_(0, config_.cost_scale) {
+    cur_queue_(config_.cost_scale) {
   config.Check();
   prev_toks_.reserve(1000);
   cur_toks_.reserve(1000);
@@ -133,8 +138,6 @@ template <typename FST, typename Token>
 LatticeFasterDecoderCombineTpl<FST, Token>::~LatticeFasterDecoderCombineTpl() {
   ClearActiveTokens();
   if (delete_fst_) delete fst_;
-  //prev_toks_.clear();
-  //cur_toks_.clear();
 }
 
 template <typename FST, typename Token>
@@ -819,7 +822,7 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessForFrame(
       BaseFloat cur_cost = tok->tot_cost;
       StateId state = tok->state_id;
       if (cur_cost > cur_cutoff &&
-          num_toks_processed < config_.min_active) { // Don't bother processing
+          num_toks_processed > config_.min_active) { // Don't bother processing
                                                      // successors.
         break; // This is a priority queue. The following tokens will be worse
       } else if (cur_cost + adaptive_beam < cur_cutoff) {
@@ -848,7 +851,7 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessForFrame(
 
           // "changed" tells us whether the new token has a different
           // cost from before, or is new.
-          if (changed && !new_tok->in_queue) {
+          if (changed) {
             cur_queue_.Push(new_tok);
           }
         }
@@ -948,7 +951,7 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessNonemitting(
       BaseFloat cur_cost = tok->tot_cost;
       StateId state = tok->state_id;
       if (cur_cost > cur_cutoff &&
-          num_toks_processed < config_.min_active) { // Don't bother processing
+          num_toks_processed > config_.min_active) { // Don't bother processing
                                                      // successors.
         break; // This is a priority queue. The following tokens will be worse
       } else if (cur_cost + adaptive_beam < cur_cutoff) {
@@ -977,7 +980,7 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessNonemitting(
 
           // "changed" tells us whether the new token has a different
           // cost from before, or is new.
-          if (changed && !new_tok->in_queue) {
+          if (changed) {
             cur_queue_.Push(new_tok);
           }
         }
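
The two hunks above at old lines 822 and 951 flip '<' to '>' in the min_active test. Read in context, the corrected condition means: since tokens come off the queue best-first, stop expanding successors only once at least min_active tokens have already been processed and the current cost is beyond the cutoff. A short restatement with a hypothetical helper name (not part of the patch):

// Hypothetical helper, not part of the patch: restates the corrected
// pruning test from ProcessForFrame()/ProcessNonemitting().
static bool ShouldStopExpanding(float cur_cost, float cur_cutoff,
                                int num_toks_processed, int min_active) {
  // Stop only if we are past the cutoff AND have already processed more
  // than min_active tokens; tokens come off the queue best-first, so
  // everything after this point can only be worse.
  return cur_cost > cur_cutoff && num_toks_processed > min_active;
}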

src/decoder/lattice-faster-decoder-combine-bucketqueue.h

Lines changed: 36 additions & 39 deletions
@@ -240,33 +240,29 @@ struct BackpointerToken {
 template<typename Token>
 class BucketQueue {
  public:
-  /** Constructor. 'cost_scale' is a scale that we multiply the token costs by
-   *  before intergerizing; a larger value means more buckets.
-   *  'best_cost_estimate' is an estimate of the best (lowest) cost that
-   *  we are likely to encounter (e.g. the best cost that we have seen so far).
-   *  It is used to initialize 'bucket_storage_begin_'.
-   */
-  BucketQueue(BaseFloat best_cost_estimate, BaseFloat cost_scale = 1.0);
-
-  // Add a Token to the queue; sets the field tok->in_queue to true (it is not
+  // Constructor. 'cost_scale' is a scale that we multiply the token costs by
+  // before intergerizing; a larger value means more buckets.
+  // 'bucket_offset_' is initialized to "15 * cost_scale_". It is an empirical
+  // value in case we trigger the re-allocation in normal case, since we do in
+  // fact normalize costs to be not far from zero on each frame.
+  BucketQueue(BaseFloat cost_scale = 1.0);
+
+  // Adds Token to the queue; sets the field tok->in_queue to true (it is not
   // an error if it was already true).
   // If a Token was already in the queue but its cost improves, you should
   // just Push it again. It will be added to (possibly) a different bucket, but
-  // the old entry will remain. The old entry in the queue will be considered as
-  // nonexistent when we try to pop it and notice that the recorded cost
-  // does not match the cost in the Token. (Actually, we use in_queue to decide
-  // an entry is nonexistent or not.) This strategy means that you may not
-  // delete Tokens as long as pointers to them might exist in this queue (hence,
-  // it is probably best to only ever have this queue as a local variable inside
-  // a function).
+  // the old entry will remain. We use "tok->in_queue" to decide
+  // an entry is nonexistent or not. When pop a Token off, the field
+  // 'tok->in_queue' is set to false. So the old entry in the queue will be
+  // considered as nonexistent when we try to pop it.
   void Push(Token *tok);
 
   // Removes and returns the next Token 'tok' in the queue, or NULL if there
   // were no Tokens left. Sets tok->in_queue to false for the returned Token.
   Token* Pop();
 
-  // Clear all the individual buckets. Set 'first_occupied_vec_index_' to the
-  // value past the end of buckets_.
+  // Clears all the individual buckets. Sets 'first_nonempty_bucket_index_' to
+  // the end of buckets_.
   void Clear();
 
  private:
@@ -283,21 +279,20 @@ class BucketQueue {
   // then access buckets_[vec_index].
   std::vector<std::vector<Token*> > buckets_;
 
-  // The lowest-numbered vec_index that is occupied (i.e. the first one which
-  // has any elements). Will be updated as we add or remove tokens.
-  // If this corresponds to a value past the end of buckets_, we interpret it
-  // as 'there are no buckets with entries'.
-  int32 first_occupied_vec_index_;
-
   // An offset that determines how we index into the buckets_ vector;
-  // may be interpreted as a 'bucket_index' that is better than any one that
-  // we are likely to see.
   // In the constructor this will be initialized to something like
-  // bucket_storage_begin_ = std::floor((best_cost_estimate - 15) * cost_scale)
-  // which will make it unlikely that we have to change this value in future if
-  // we get a much better Token (this is expensive because it involves
-  // reallocating 'buckets_').
-  int32 bucket_storage_begin_;
+  // "15 * cost_scale_" which will make it unlikely that we have to change this
+  // value in future if we get a much better Token (this is expensive because it
+  // involves reallocating 'buckets_').
+  int32 bucket_offset_;
+
+  // first_nonempty_bucket_index_ is an integer in the range [0,
+  // buckets_.size() - 1] which is not larger than the index of the first
+  // nonempty element of buckets_.
+  int32 first_nonempty_bucket_index_;
+
+  // Synchronizes with first_nonempty_bucket_index_.
+  std::vector<Token*> *first_nonempty_bucket_;
 };
 
 /** This is the "normal" lattice-generating decoder.
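
As a worked example of the indexing that bucket_offset_ supports (the numbers are illustrative, using a cost_scale of 1.0): a token with tot_cost = -3.2 maps to bucket floor(-3.2 * 1.0) + 15 = -4 + 15 = 11, while a slightly worse token with tot_cost = -2.9 maps to bucket 12, so lower-cost tokens land in lower-numbered buckets. A tiny standalone check:

#include <cmath>
#include <cstdio>

int main() {
  const float cost_scale = 1.0f;  // illustrative value
  const int bucket_offset = static_cast<int>(15 * cost_scale);
  const float costs[] = {-3.2f, -2.9f, 0.0f};
  for (float c : costs) {
    int index = static_cast<int>(std::floor(c * cost_scale)) + bucket_offset;
    std::printf("tot_cost %+.1f -> bucket %d\n", c, index);  // 11, 12, 15
  }
  return 0;
}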
@@ -543,14 +538,16 @@ class LatticeFasterDecoderCombineTpl {
   void ProcessForFrame(DecodableInterface *decodable);
 
   /// Processes nonemitting (epsilon) arcs for one frame.
-  /// Calls this function once when all frames were processed.
-  /// Or calls it in GetRawLattice() to generate the complete token list for
-  /// the last frame. [Deal With the tokens in map "cur_toks_" which would
-  /// only contains emittion tokens from previous frame.]
-  /// If the map, "token_orig_cost", isn't NULL, we build the map which will
-  /// be used to recover "active_toks_[last_frame]" token list for the last
-  /// frame.
-  void ProcessNonemitting(std::unordered_map<Token*, BaseFloat> *token_orig_cost);
+  /// This function is called from FinalizeDecoding(), and also from
+  /// GetRawLattice() if GetRawLattice() is called before FinalizeDecoding() is
+  /// called. In the latter case, RecoverLastTokenList() is called later by
+  /// GetRawLattice() to restore the state prior to ProcessNonemitting() being
+  /// called, since ProcessForFrame() does not expect nonemitting arcs to
+  /// already have been propagagted. ["token_orig_cost" isn't NULL in the
+  /// latter case, we build the map which will be used to recover
+  /// "active_toks_[last_frame]" token list for the last frame.]
+  void ProcessNonemitting(
+      std::unordered_map<Token*, BaseFloat> *token_orig_cost);
 
   /// When GetRawLattice() is called during decoding, the
   /// active_toks_[last_frame] is changed. To keep the consistency of function
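
A hedged outline of the two call paths the new comment describes; only the method names come from this header, and the exact signatures and call sites are assumptions:

// Path 1, decoding has finished:
//   for each frame: decoder.ProcessForFrame(&decodable);
//   decoder.FinalizeDecoding();   // calls ProcessNonemitting(NULL)
//
// Path 2, a lattice is requested mid-decoding:
//   decoder.GetRawLattice(...);   // calls ProcessNonemitting(&token_orig_cost),
//                                 // builds the lattice, then calls
//                                 // RecoverLastTokenList() to restore
//                                 // active_toks_[last_frame] so that
//                                 // ProcessForFrame() can resume.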
