@@ -44,6 +44,8 @@ CudaOnlinePipelineDynamicBatcher::CudaOnlinePipelineDynamicBatcher(
4444
4545 batcher_thread_.reset (new std::thread (
4646 &CudaOnlinePipelineDynamicBatcher::BatcherThreadLoop, this ));
47+
48+ n_chunks_per_corr_.reserve (num_channels_);
4749}
4850
4951CudaOnlinePipelineDynamicBatcher::~CudaOnlinePipelineDynamicBatcher () {
@@ -64,6 +66,7 @@ void CudaOnlinePipelineDynamicBatcher::Push(
6466 backlog_.push_back (
6567 {corr_id, is_first_chunk, is_last_chunk, std::move (wave_samples)});
6668 }
69+ ++n_chunks_per_corr_[corr_id];
6770 n_chunks_not_done_.fetch_add (1 , std::memory_order_release);
6871}
6972
@@ -144,9 +147,20 @@ void CudaOnlinePipelineDynamicBatcher::BatcherThreadLoop() {
144147 curr_batch_->corr_ids , curr_batch_->h_all_waveform ,
145148 curr_batch_->n_samples_valid , curr_batch_->is_first_chunk ,
146149 curr_batch_->is_last_chunk );
147- n_chunks_not_done_.fetch_sub (curr_batch_->Size (),
148- std::memory_order_release);
149150
151+ {
152+ // Update counts
153+ std::lock_guard<std::mutex> lk (next_batch_and_backlog_m_);
154+ n_chunks_not_done_.fetch_sub (curr_batch_->Size (),
155+ std::memory_order_release);
156+ for (size_t i = 0 ; i < curr_batch_->corr_ids .size (); ++i) {
157+ CorrelationID corr_id = curr_batch_->corr_ids [i];
158+ --n_chunks_per_corr_[corr_id];
159+ if (curr_batch_->is_last_chunk [i]) {
160+ n_chunks_per_corr_.erase (corr_id);
161+ }
162+ }
163+ }
150164 curr_batch_->Clear ();
151165 }
152166
@@ -167,5 +181,12 @@ void CudaOnlinePipelineDynamicBatcher::WaitForCompletion() {
167181 cuda_pipeline_.WaitForLatticeCallbacks ();
168182}
169183
184+ int CudaOnlinePipelineDynamicBatcher::GetNumPendingChunks (CorrelationID corr_id) {
185+ std::lock_guard<std::mutex> lk (next_batch_and_backlog_m_);
186+ if (n_chunks_per_corr_.find (corr_id) == n_chunks_per_corr_.end ())
187+ return 0 ;
188+ return n_chunks_per_corr_[corr_id];
189+ }
190+
170191} // namespace cuda_decoder
171192} // namespace kaldi
0 commit comments