2323
2424#include < iterator>
2525#include < limits>
26- #include < algorithm>
2726
2827namespace kaldi {
2928namespace chain {
@@ -34,13 +33,16 @@ namespace chain {
3433// for end-to-end training 'supervision's.
3534
3635GenericNumeratorComputation::GenericNumeratorComputation (
36+ const GenericNumeratorComputationOptions &opts,
3737 const Supervision &supervision,
3838 const CuMatrixBase<BaseFloat> &nnet_output):
3939 supervision_ (supervision),
40- nnet_output_ (nnet_output) {
40+ nnet_output_ (nnet_output),
41+ opts_ (opts) {
4142 KALDI_ASSERT (supervision.num_sequences *
4243 supervision.frames_per_sequence == nnet_output.NumRows () &&
4344 supervision.label_dim == nnet_output.NumCols ());
45+ NVTX_RANGE (__func__);
4446
4547 using std::vector;
4648 int num_sequences = supervision_.num_sequences ;
@@ -119,6 +121,7 @@ GenericNumeratorComputation::GenericNumeratorComputation(
119121
120122void GenericNumeratorComputation::AlphaFirstFrame (int seq,
121123 Matrix<BaseFloat> *alpha) {
124+ NVTX_RANGE (__func__);
122125 const int32 num_frames = supervision_.frames_per_sequence ,
123126 num_states = supervision_.e2e_fsts [seq].NumStates ();
124127 alpha->Resize (num_frames + 1 , num_states + 1 , kSetZero );
@@ -133,6 +136,7 @@ void GenericNumeratorComputation::CopySpecificPdfsIndirect(
133136 const std::vector<MatrixIndexT> &indices,
134137 Matrix<BaseFloat> *out) {
135138 KALDI_ASSERT (nnet_output_stride_ == nnet_output_.Stride ());
139+ NVTX_RANGE (__func__);
136140 const int32 num_sequences = supervision_.num_sequences ,
137141 frames_per_sequence = supervision_.frames_per_sequence ;
138142
@@ -156,6 +160,7 @@ void GenericNumeratorComputation::CopySpecificPdfsIndirect(
156160BaseFloat GenericNumeratorComputation::AlphaRemainingFrames (int seq,
157161 const Matrix<BaseFloat> &probs,
158162 Matrix<BaseFloat> *alpha) {
163+ NVTX_RANGE (__func__);
159164 // Define some variables to make things nicer
160165 const int32 num_sequences = supervision_.num_sequences ,
161166 num_frames = supervision_.frames_per_sequence ;
@@ -212,6 +217,7 @@ BaseFloat GenericNumeratorComputation::AlphaRemainingFrames(int seq,
212217bool GenericNumeratorComputation::ForwardBackward (
213218 BaseFloat *total_loglike,
214219 CuMatrixBase<BaseFloat> *nnet_output_deriv) {
220+ NVTX_RANGE (__func__);
215221 KALDI_ASSERT (total_loglike != NULL );
216222 KALDI_ASSERT (nnet_output_deriv != NULL );
217223 KALDI_ASSERT (nnet_output_deriv->NumCols () == nnet_output_.NumCols ());
@@ -221,35 +227,71 @@ bool GenericNumeratorComputation::ForwardBackward(
221227 const int32 num_sequences = supervision_.num_sequences ;
222228
223229 bool ok = true ;
224- Matrix<BaseFloat> alpha;
225- Matrix<BaseFloat> beta;
226230 Matrix<BaseFloat> probs;
227- Matrix<BaseFloat> derivs;
231+ Matrix<BaseFloat> derivs; // Don't need nthreads copies to avoid data
232+ // races since each sequence operates on a
233+ // distinct set of columns
228234
229235 // We selectively copy only those pdfs we need
230236 CopySpecificPdfsIndirect (nnet_output_, index_to_pdf_, &probs);
231237
232238 derivs.Resize (probs.NumRows (), probs.NumCols ());
233239 derivs.Set (-std::numeric_limits<BaseFloat>::infinity ());
234240
235- for (int seq = 0 ; seq < num_sequences; ++seq) {
236- // Forward part
237- AlphaFirstFrame (seq, &alpha);
238- partial_loglike += AlphaRemainingFrames (seq, probs, &alpha);
239-
240- // Backward part
241- BetaLastFrame (seq, alpha, &beta);
242- BetaRemainingFrames (seq, probs, alpha, &beta, &derivs);
243- if (GetVerboseLevel () >= 1 )
244- ok = ok && CheckValues (seq, probs, alpha, beta, derivs);
241+ // Use opts_.num_threads workers when it is positive; otherwise fall back to the available hardware concurrency
242+ unsigned int nthreads = opts_.num_threads > 0 ? opts_.num_threads :
243+ std::thread::hardware_concurrency ();
244+ // Naive load balancing, each thread gets a chunk of the sequences to process
245+ unsigned int num_sequences_per_thread =
246+ (num_sequences + nthreads - 1 ) / nthreads;
247+
248+ // Allocate one alpha and beta matrix per thread to avoid contention
249+ std::vector<Matrix<BaseFloat>> alpha (nthreads);
250+ std::vector<Matrix<BaseFloat>> beta (nthreads);
251+
252+ // Per thread partial values and boolean
253+ std::vector<BaseFloat> partial_loglike_mt (nthreads, static_cast <BaseFloat>(0 ));
254+ std::vector<bool > ok_mt (nthreads, true );
255+
256+ // Lambda function for each thread's portion of the computation
257+ auto thread_lambda = [&] (int thread, int num_sequences, int num_sequences_per_thread) {
258+ int seq_st = thread * num_sequences_per_thread;
259+ int seq_en = seq_st + num_sequences_per_thread;
260+ seq_en = (seq_en <= num_sequences) ? seq_en : num_sequences;
261+ for (int seq = seq_st; seq < seq_en; ++seq) {
262+ // Forward part
263+ AlphaFirstFrame (seq, &alpha[thread]);
264+ partial_loglike_mt[thread] += AlphaRemainingFrames (seq, probs, &alpha[thread]);
265+
266+ // Backward part
267+ BetaLastFrame (seq, alpha[thread], &beta[thread]);
268+ BetaRemainingFrames (seq, probs, alpha[thread], &beta[thread], &derivs);
269+ if (GetVerboseLevel () >= 1 )
270+ ok_mt[thread] = ok_mt[thread] && CheckValues (seq, probs, alpha[thread], beta[thread], derivs);
271+ }
272+ return ;
273+ };
274+
275+ std::vector<std::thread> workers (nthreads);
276+ for (int thread = 0 ; thread < nthreads; ++thread)
277+ // Launch all threads
278+ workers[thread] = std::thread (thread_lambda, thread, num_sequences, num_sequences_per_thread);
279+ for (int thread = 0 ; thread < nthreads; ++thread) {
280+ // Join threads back in
281+ workers[thread].join ();
282+ // Reduce thread values to a single value
283+ partial_loglike += partial_loglike_mt[thread];
284+ ok = ok && ok_mt[thread];
245285 }
286+
246287 // Transfer and add the derivatives to the values in the matrix
247288 AddSpecificPdfsIndirect (&derivs, index_to_pdf_, nnet_output_deriv);
248289 *total_loglike = partial_loglike;
249290 return ok;
250291}
251292
252293BaseFloat GenericNumeratorComputation::ComputeObjf () {
294+ NVTX_RANGE (__func__);
253295 BaseFloat partial_loglike = 0 ;
254296 const int32 num_sequences = supervision_.num_sequences ;
255297
@@ -275,6 +317,7 @@ BaseFloat GenericNumeratorComputation::GetTotalProb(
275317void GenericNumeratorComputation::BetaLastFrame (int seq,
276318 const Matrix<BaseFloat> &alpha,
277319 Matrix<BaseFloat> *beta) {
320+ NVTX_RANGE (__func__);
278321 // Sets up the beta quantity on the last frame (frame ==
279322 // frames_per_sequence_). Note that the betas we use here contain a
280323 // 1/(tot-prob) factor in order to simplify the backprop.
@@ -298,6 +341,7 @@ void GenericNumeratorComputation::BetaRemainingFrames(int seq,
298341 const Matrix<BaseFloat> &alpha,
299342 Matrix<BaseFloat> *beta,
300343 Matrix<BaseFloat> *derivs) {
344+ NVTX_RANGE (__func__);
301345 const int32
302346 num_sequences = supervision_.num_sequences ,
303347 num_frames = supervision_.frames_per_sequence ,
@@ -340,6 +384,7 @@ void GenericNumeratorComputation::AddSpecificPdfsIndirect(
340384 Matrix<BaseFloat> *logprobs,
341385 const std::vector<MatrixIndexT> &indices,
342386 CuMatrixBase<BaseFloat> *output) {
387+ NVTX_RANGE (__func__);
343388 const int32 num_sequences = supervision_.num_sequences ,
344389 frames_per_sequence = supervision_.frames_per_sequence ;
345390
0 commit comments