Skip to content

Commit 7460d99

Browse files
authored
[cudadecoder] Expose API to wait for separate streams (#4681)
1 parent df1e911 commit 7460d99

File tree

2 files changed

+29
-2
lines changed

2 files changed

+29
-2
lines changed

src/cudadecoder/cuda-online-pipeline-dynamic-batcher.cc

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ CudaOnlinePipelineDynamicBatcher::CudaOnlinePipelineDynamicBatcher(
4444

4545
batcher_thread_.reset(new std::thread(
4646
&CudaOnlinePipelineDynamicBatcher::BatcherThreadLoop, this));
47+
48+
n_chunks_per_corr_.reserve(num_channels_);
4749
}
4850

4951
CudaOnlinePipelineDynamicBatcher::~CudaOnlinePipelineDynamicBatcher() {
@@ -64,6 +66,7 @@ void CudaOnlinePipelineDynamicBatcher::Push(
6466
backlog_.push_back(
6567
{corr_id, is_first_chunk, is_last_chunk, std::move(wave_samples)});
6668
}
69+
++n_chunks_per_corr_[corr_id];
6770
n_chunks_not_done_.fetch_add(1, std::memory_order_release);
6871
}
6972

@@ -144,9 +147,20 @@ void CudaOnlinePipelineDynamicBatcher::BatcherThreadLoop() {
144147
curr_batch_->corr_ids, curr_batch_->h_all_waveform,
145148
curr_batch_->n_samples_valid, curr_batch_->is_first_chunk,
146149
curr_batch_->is_last_chunk);
147-
n_chunks_not_done_.fetch_sub(curr_batch_->Size(),
148-
std::memory_order_release);
149150

151+
{
152+
// Update counts
153+
std::lock_guard<std::mutex> lk(next_batch_and_backlog_m_);
154+
n_chunks_not_done_.fetch_sub(curr_batch_->Size(),
155+
std::memory_order_release);
156+
for (size_t i = 0; i < curr_batch_->corr_ids.size(); ++i) {
157+
CorrelationID corr_id = curr_batch_->corr_ids[i];
158+
--n_chunks_per_corr_[corr_id];
159+
if (curr_batch_->is_last_chunk[i]) {
160+
n_chunks_per_corr_.erase(corr_id);
161+
}
162+
}
163+
}
150164
curr_batch_->Clear();
151165
}
152166

@@ -167,5 +181,12 @@ void CudaOnlinePipelineDynamicBatcher::WaitForCompletion() {
167181
cuda_pipeline_.WaitForLatticeCallbacks();
168182
}
169183

184+
// Returns the chunk count currently recorded in n_chunks_per_corr_ for
// corr_id, or 0 when corr_id has no entry in the map.
int CudaOnlinePipelineDynamicBatcher::GetNumPendingChunks(CorrelationID corr_id) {
  // Same mutex the batcher thread holds while it decrements/erases the
  // per-correlation counts.
  std::lock_guard<std::mutex> lk(next_batch_and_backlog_m_);
  // One lookup through find(): unlike operator[], it can never insert a
  // default entry for an unknown corr_id, and it avoids hashing twice.
  auto it = n_chunks_per_corr_.find(corr_id);
  return (it == n_chunks_per_corr_.end()) ? 0 : it->second;
}
190+
170191
} // namespace cuda_decoder
171192
} // namespace kaldi

src/cudadecoder/cuda-online-pipeline-dynamic-batcher.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,11 @@ class CudaOnlinePipelineDynamicBatcher {
5151
// return
5252
void Push(CorrelationID corr_id, bool is_first_chunk, bool is_last_chunk,
5353
const SubVector<BaseFloat> &wave_samples);
54+
55+
// Wait for completion of the submitted chunks
5456
void WaitForCompletion();
57+
// Get the number of chunks still pending for corr_id (0 if none are
// tracked); intended for poll-like processing
58+
int GetNumPendingChunks(CorrelationID corr_id);
5559

5660
private:
5761
// Batches created by this Batcher
@@ -125,7 +129,9 @@ class CudaOnlinePipelineDynamicBatcher {
125129

126130
std::vector<const std::string *> partial_hypotheses_;
127131
std::vector<bool> end_points_;
132+
128133
std::atomic<std::uint32_t> n_chunks_not_done_;
134+
std::unordered_map<CorrelationID, int> n_chunks_per_corr_;
129135

130136
int max_batch_size_;
131137
int num_channels_;

0 commit comments

Comments
 (0)