Skip to content

Commit 9796afd

Browse files
authored
[src] Shutdown CUDA online pipeline deterministically (#4570)
Also reformat string literals that ended up chopped into too many pieces over too many lines, and add commentary on the importance of data member ordering within the class.
1 parent d0937dc commit 9796afd

File tree

1 file changed

+34
-24
lines changed

1 file changed

+34
-24
lines changed

src/cudadecoder/batched-threaded-nnet3-cuda-online-pipeline.h

Lines changed: 34 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
// See the License for the specific language governing permissions and
1616
// limitations under the License.
1717

18-
#ifndef KALDI_CUDADECODER_BATCHED_THREADED_CUDA_ONLINE_PIPELINE_H_
19-
#define KALDI_CUDADECODER_BATCHED_THREADED_CUDA_ONLINE_PIPELINE_H_
18+
#ifndef KALDI_CUDADECODER_BATCHED_THREADED_NNET3_CUDA_ONLINE_PIPELINE_H_
19+
#define KALDI_CUDADECODER_BATCHED_THREADED_NNET3_CUDA_ONLINE_PIPELINE_H_
2020

2121
#if HAVE_CUDA
2222

@@ -66,24 +66,22 @@ struct BatchedThreadedNnet3CudaOnlinePipelineConfig {
6666
use_gpu_feature_extraction(true) {}
6767
void Register(OptionsItf *po) {
6868
po->Register("max-batch-size", &max_batch_size,
69-
"The maximum execution batch size. "
70-
"Larger = Better throughput slower latency.");
69+
"The maximum execution batch size."
70+
" Larger = better throughput, but slower latency.");
7171
po->Register("num-channels", &num_channels,
72-
"The number of parallel audio channels. This is the maximum "
73-
"number of parallel audio channels supported by the pipeline"
74-
". This should be larger "
75-
"than max_batch_size.");
72+
"The number of parallel audio channels. This is the maximum"
73+
" number of parallel audio channels supported by the pipeline."
74+
" This should be larger than max_batch_size.");
7675
po->Register("cuda-worker-threads", &num_worker_threads,
77-
"(optional) The total number of CPU threads launched to "
78-
"process CPU tasks. -1 = use std::hardware_concurrency()");
76+
"The total number of CPU threads launched to process CPU"
77+
" tasks. -1 = use std::hardware_concurrency().");
7978
po->Register("determinize-lattice", &determinize_lattice,
8079
"Determinize the lattice before output.");
8180
po->Register("cuda-decoder-copy-threads", &num_decoder_copy_threads,
82-
"Advanced - Number of worker threads used in the "
83-
"decoder for "
84-
"the host to host copies.");
81+
"Advanced - Number of worker threads used in the"
82+
" decoder for the host to host copies.");
8583
po->Register("gpu-feature-extract", &use_gpu_feature_extraction,
86-
"Use GPU feature extraction");
84+
"Use GPU feature extraction.");
8785

8886
feature_opts.Register(po);
8987
decoder_opts.Register(po);
@@ -138,10 +136,9 @@ class BatchedThreadedNnet3CudaOnlinePipeline {
138136
word_syms_(NULL) {
139137
config_.compute_opts.CheckAndFixConfigs(am_nnet_->GetNnet().Modulus());
140138
config_.CheckAndFixConfigs();
141-
int num_worker_threads = config_.num_worker_threads;
142-
thread_pool_.reset(new ThreadPoolLight(num_worker_threads));
143-
144139
Initialize(decode_fst);
140+
int num_worker_threads = config.num_worker_threads;
141+
thread_pool_ = std::make_unique<ThreadPoolLight>(num_worker_threads);
145142
}
146143

147144
~BatchedThreadedNnet3CudaOnlinePipeline();
@@ -415,22 +412,35 @@ class BatchedThreadedNnet3CudaOnlinePipeline {
415412
// Only used if feature extraction is run on the CPU
416413
std::vector<std::unique_ptr<OnlineNnet2FeaturePipeline>> feature_pipelines_;
417414

418-
// HCLG graph : CudaFst object is a host object, but contains
419-
// data stored in
420-
// GPU memory
415+
// Ordering of the cuda_fst_ w.r.t. thread_pool_ and the decoder is important:
416+
// order of destruction is bottom-up, opposite to the order of construction.
417+
// We want the FST object, which is entirely passive and only frees device
418+
// FST representation when destroyed, to survive both the thread pool and the
419+
// decoder, which both may perform pending work during destruction. Since no
420+
// new work may be fed into this object while it is being destroyed, the
421+
// relative order of the latter two is unimportant, but just in case, FST must
422+
// stay around until the other two are positively quiescent.
423+
424+
// HCLG graph. CudaFst is a host object, but owns pointers to the data stored
425+
// in GPU memory.
421426
std::unique_ptr<CudaFst> cuda_fst_;
422-
std::unique_ptr<CudaDecoder> cuda_decoder_;
423427

428+
// The thread pool receives data from device and post-processes it. This class
429+
// destructor blocks until the thread pool is drained of work items.
424430
std::unique_ptr<ThreadPoolLight> thread_pool_;
425431

432+
// The decoder owns thread(s) that reconstruct lattices transferred from the
433+
// device in a compacted form as arrays with offsets instead of pointers.
434+
std::unique_ptr<CudaDecoder> cuda_decoder_;
435+
426436
// Used for debugging
427437
const fst::SymbolTable *word_syms_;
428438
// Used when printing to stdout for debugging purposes
429439
std::mutex stdout_m_;
430440
};
431441

432-
} // end namespace cuda_decoder
433-
} // end namespace kaldi.
442+
} // namespace cuda_decoder
443+
} // namespace kaldi
434444

435445
#endif // HAVE_CUDA
436-
#endif // KALDI_CUDADECODER_BATCHED_THREADED_CUDA_ONLINE_PIPELINE_H_
446+
#endif // KALDI_CUDADECODER_BATCHED_THREADED_NNET3_CUDA_ONLINE_PIPELINE_H_

0 commit comments

Comments (0)