Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/cudadecoder/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ OBJFILES = cuda-decoder.o cuda-decoder-kernels.o cuda-fst.o \
batched-threaded-nnet3-cuda-pipeline2.o \
batched-static-nnet3.o batched-static-nnet3-kernels.o \
cuda-online-pipeline-dynamic-batcher.o decodable-cumatrix.o \
cuda-pipeline-common.o lattice-postprocessor.o
cuda-pipeline-common.o lattice-postprocessor.o \
thread-pool-cia.o

LIBNAME = kaldi-cudadecoder

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ void BatchedThreadedNnet3CudaOnlinePipeline::CompactWavesToMatrix(

for (size_t i = 0; i < wave_samples.size(); i += batch_size) {

auto task = [i, this, &wave_samples, &m, &cv, &tasks_remaining, &batch_size](void *ignore1, uint64_t ignore2, void *ignore3) {
auto task = [i, this, &wave_samples, &m, &cv, &tasks_remaining, &batch_size]() {
nvtxRangePush("CompactWavesToMatrix task");
for (size_t j = i; j < std::min(i + batch_size, wave_samples.size()); ++j) {
const SubVector<BaseFloat> &src = wave_samples[j];
Expand All @@ -281,7 +281,7 @@ void BatchedThreadedNnet3CudaOnlinePipeline::CompactWavesToMatrix(
}
nvtxRangePop();
};
batching_copy_thread_pool_->Push({task, nullptr, 0, nullptr});
batching_copy_thread_pool_->submit(task);
}

// wait for all threads to finish
Expand Down
6 changes: 4 additions & 2 deletions src/cudadecoder/batched-threaded-nnet3-cuda-online-pipeline.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
#include "nnet3/nnet-optimize.h"
#include "online2/online-nnet2-feature-pipeline.h"

#include "cudadecoder/thread-pool-cia.h"

namespace kaldi {
namespace cuda_decoder {

Expand Down Expand Up @@ -165,7 +167,7 @@ class BatchedThreadedNnet3CudaOnlinePipeline {

int num_batching_copy_threads = config_.num_batching_copy_threads;
if (num_batching_copy_threads > 0) {
batching_copy_thread_pool_ = std::make_unique<ThreadPoolLight>(num_batching_copy_threads);
batching_copy_thread_pool_ = std::make_unique<work_stealing_thread_pool>(num_batching_copy_threads);
}

Initialize(decode_fst);
Expand Down Expand Up @@ -519,7 +521,7 @@ class BatchedThreadedNnet3CudaOnlinePipeline {
// destructor blocks until the thread pool is drained of work items.
std::unique_ptr<ThreadPoolLight> thread_pool_;

std::unique_ptr<ThreadPoolLight> batching_copy_thread_pool_;
std::unique_ptr<work_stealing_thread_pool> batching_copy_thread_pool_;

// The decoder owns thread(s) that reconstruct lattices transferred from the
// device in a compacted form as arrays with offsets instead of pointers.
Expand Down
6 changes: 6 additions & 0 deletions src/cudadecoder/thread-pool-cia.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#include <cudadecoder/thread-pool-cia.h>

namespace kaldi {
thread_local work_stealing_queue* work_stealing_thread_pool::local_work_queue;
thread_local unsigned int work_stealing_thread_pool::my_index;
}
Loading