
Commit 8e2bbd2

akshaysubr authored and danpovey committed
[src] Enable multiple threads for chain-generic-numerator to remove CPU bottleneck (#3766)
1 parent 2b30a1e commit 8e2bbd2

18 files changed, +190 -19 lines changed

src/chain/chain-denominator.cc

Lines changed: 7 additions & 0 deletions

@@ -108,6 +108,7 @@ void DenominatorComputation::AlphaFirstFrame() {
 
 // the alpha computation for some 0 < t <= num_time_steps_.
 void DenominatorComputation::AlphaGeneralFrame(int32 t) {
+  NVTX_RANGE(__func__);
   KALDI_ASSERT(t > 0 && t <= frames_per_sequence_);
   BaseFloat *this_alpha = alpha_.RowData(t);
   const BaseFloat *prev_alpha_dash = alpha_.RowData(t - 1);
@@ -186,6 +187,7 @@ void DenominatorComputation::AlphaGeneralFrame(int32 t) {
 }
 
 void DenominatorComputation::AlphaDash(int32 t) {
+  NVTX_RANGE(__func__);
   BaseFloat *this_alpha = alpha_.RowData(t);
 
   // create a 'fake matrix' for the regular alphas- view this row as a matrix.
@@ -209,6 +211,7 @@ void DenominatorComputation::AlphaDash(int32 t) {
 
 // compute beta from beta-dash.
 void DenominatorComputation::Beta(int32 t) {
+  NVTX_RANGE(__func__);
   BaseFloat *this_beta_dash = beta_.RowData(t % 2);
   // create a 'fake matrix' for the regular beta-dash (which is
   // the counterpart of alpha-dash)- view this row as a matrix.
@@ -231,6 +234,7 @@ void DenominatorComputation::Beta(int32 t) {
 }
 
 BaseFloat DenominatorComputation::Forward() {
+  NVTX_RANGE(__func__);
   AlphaFirstFrame();
   AlphaDash(0);
   for (int32 t = 1; t <= frames_per_sequence_; t++) {
@@ -241,6 +245,7 @@ BaseFloat DenominatorComputation::Forward() {
 }
 
 BaseFloat DenominatorComputation::ComputeTotLogLike() {
+  NVTX_RANGE(__func__);
   tot_prob_.Resize(num_sequences_);
   // View the last alpha-dash as a matrix of size num-hmm-states by num-sequences.
   CuSubMatrix<BaseFloat> last_alpha_dash(
@@ -281,6 +286,7 @@ BaseFloat DenominatorComputation::ComputeTotLogLike() {
 bool DenominatorComputation::Backward(
     BaseFloat deriv_weight,
     CuMatrixBase<BaseFloat> *nnet_output_deriv) {
+  NVTX_RANGE(__func__);
   BetaDashLastFrame();
   Beta(frames_per_sequence_);
   for (int32 t = frames_per_sequence_ - 1; t >= 0; t--) {
@@ -332,6 +338,7 @@ void DenominatorComputation::BetaDashLastFrame() {
 }
 
 void DenominatorComputation::BetaDashGeneralFrame(int32 t) {
+  NVTX_RANGE(__func__);
   KALDI_ASSERT(t >= 0 && t < frames_per_sequence_);
   int32 num_pdfs = exp_nnet_output_transposed_.NumRows();
   // t_wrapped gives us the time-index we use when indexing

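All seven changes in this file are NVTX_RANGE(__func__) annotations. The macro (see the NvtxTracer definition in src/cudamatrix/cu-common.cc below) creates a scoped object that pushes a named profiler range on entry and pops it when the function returns, so nested calls show up as nested ranges in tools such as Nsight. A toy, standalone illustration of that RAII push/pop pattern — plain prints instead of NVTX calls, not Kaldi code:

#include <cstdio>

// Toy stand-in for NvtxTracer: push on construction, pop on destruction,
// so one macro line at the top of a function covers every exit path.
struct ScopedRange {
  explicit ScopedRange(const char *name) : name_(name) {
    std::printf("push %s\n", name_);  // where nvtxRangePushEx would go
  }
  ~ScopedRange() { std::printf("pop  %s\n", name_); }  // nvtxRangePop
  const char *name_;
};
#define TOY_RANGE() ScopedRange toy_range_(__func__)

void AlphaGeneralFrame() { TOY_RANGE(); /* per-frame work */ }

void Forward() {
  TOY_RANGE();            // opens the "Forward" range
  for (int t = 0; t < 3; ++t)
    AlphaGeneralFrame();  // three nested "AlphaGeneralFrame" ranges
}                         // "Forward" pops here

int main() { Forward(); }
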
src/chain/chain-generic-numerator.cc

Lines changed: 60 additions & 15 deletions

@@ -23,7 +23,6 @@
 
 #include <iterator>
 #include <limits>
-#include <algorithm>
 
 namespace kaldi {
 namespace chain {
@@ -34,13 +33,16 @@ namespace chain {
 // for end-to-end training 'supervision's.
 
 GenericNumeratorComputation::GenericNumeratorComputation(
+    const GenericNumeratorComputationOptions &opts,
     const Supervision &supervision,
     const CuMatrixBase<BaseFloat> &nnet_output):
     supervision_(supervision),
-    nnet_output_(nnet_output) {
+    nnet_output_(nnet_output),
+    opts_(opts) {
   KALDI_ASSERT(supervision.num_sequences *
                supervision.frames_per_sequence == nnet_output.NumRows() &&
                supervision.label_dim == nnet_output.NumCols());
+  NVTX_RANGE(__func__);
 
   using std::vector;
   int num_sequences = supervision_.num_sequences;
@@ -119,6 +121,7 @@ GenericNumeratorComputation::GenericNumeratorComputation(
 
 void GenericNumeratorComputation::AlphaFirstFrame(int seq,
                                                   Matrix<BaseFloat> *alpha) {
+  NVTX_RANGE(__func__);
   const int32 num_frames = supervision_.frames_per_sequence,
       num_states = supervision_.e2e_fsts[seq].NumStates();
   alpha->Resize(num_frames + 1, num_states + 1, kSetZero);
@@ -133,6 +136,7 @@ void GenericNumeratorComputation::CopySpecificPdfsIndirect(
     const std::vector<MatrixIndexT> &indices,
     Matrix<BaseFloat> *out) {
   KALDI_ASSERT(nnet_output_stride_ == nnet_output_.Stride());
+  NVTX_RANGE(__func__);
   const int32 num_sequences = supervision_.num_sequences,
       frames_per_sequence = supervision_.frames_per_sequence;
 
@@ -156,6 +160,7 @@ void GenericNumeratorComputation::CopySpecificPdfsIndirect(
 BaseFloat GenericNumeratorComputation::AlphaRemainingFrames(int seq,
     const Matrix<BaseFloat> &probs,
     Matrix<BaseFloat> *alpha) {
+  NVTX_RANGE(__func__);
   // Define some variables to make things nicer
   const int32 num_sequences = supervision_.num_sequences,
       num_frames = supervision_.frames_per_sequence;
@@ -212,6 +217,7 @@ BaseFloat GenericNumeratorComputation::AlphaRemainingFrames(int seq,
 bool GenericNumeratorComputation::ForwardBackward(
     BaseFloat *total_loglike,
     CuMatrixBase<BaseFloat> *nnet_output_deriv) {
+  NVTX_RANGE(__func__);
   KALDI_ASSERT(total_loglike != NULL);
   KALDI_ASSERT(nnet_output_deriv != NULL);
   KALDI_ASSERT(nnet_output_deriv->NumCols() == nnet_output_.NumCols());
@@ -221,35 +227,71 @@ bool GenericNumeratorComputation::ForwardBackward(
   const int32 num_sequences = supervision_.num_sequences;
 
   bool ok = true;
-  Matrix<BaseFloat> alpha;
-  Matrix<BaseFloat> beta;
   Matrix<BaseFloat> probs;
-  Matrix<BaseFloat> derivs;
+  Matrix<BaseFloat> derivs;  // Don't need nthreads copies to avoid data
+                             // races since each sequence operates on a
+                             // distinct set of columns
 
   // We selectively copy only those pdfs we need
   CopySpecificPdfsIndirect(nnet_output_, index_to_pdf_, &probs);
 
   derivs.Resize(probs.NumRows(), probs.NumCols());
   derivs.Set(-std::numeric_limits<BaseFloat>::infinity());
 
-  for (int seq = 0; seq < num_sequences; ++seq) {
-    // Forward part
-    AlphaFirstFrame(seq, &alpha);
-    partial_loglike += AlphaRemainingFrames(seq, probs, &alpha);
-
-    // Backward part
-    BetaLastFrame(seq, alpha, &beta);
-    BetaRemainingFrames(seq, probs, alpha, &beta, &derivs);
-    if (GetVerboseLevel() >= 1)
-      ok = ok && CheckValues(seq, probs, alpha, beta, derivs);
+  // Set total number of workers to the available hardware concurrency
+  unsigned int nthreads = opts_.num_threads > 0 ? opts_.num_threads :
+                          std::thread::hardware_concurrency();
+  // Naive load balancing, each thread gets a chunk of the sequences to process
+  unsigned int num_sequences_per_thread =
+      (num_sequences + nthreads - 1) / nthreads;
+
+  // Allocate one alpha and beta matrix per thread to avoid contention
+  std::vector<Matrix<BaseFloat>> alpha(nthreads);
+  std::vector<Matrix<BaseFloat>> beta(nthreads);
+
+  // Per thread partial values and boolean
+  std::vector<BaseFloat> partial_loglike_mt(nthreads, static_cast<BaseFloat>(0));
+  std::vector<bool> ok_mt(nthreads, true);
+
+  // Lambda function for each thread's portion of the computation
+  auto thread_lambda = [&] (int thread, int num_sequences, int num_sequences_per_thread) {
+    int seq_st = thread * num_sequences_per_thread;
+    int seq_en = seq_st + num_sequences_per_thread;
+    seq_en = (seq_en <= num_sequences) ? seq_en : num_sequences;
+    for (int seq = seq_st; seq < seq_en; ++seq) {
+      // Forward part
+      AlphaFirstFrame(seq, &alpha[thread]);
+      partial_loglike_mt[thread] += AlphaRemainingFrames(seq, probs, &alpha[thread]);
+
+      // Backward part
+      BetaLastFrame(seq, alpha[thread], &beta[thread]);
+      BetaRemainingFrames(seq, probs, alpha[thread], &beta[thread], &derivs);
+      if (GetVerboseLevel() >= 1)
+        ok_mt[thread] = ok_mt[thread] && CheckValues(seq, probs, alpha[thread], beta[thread], derivs);
+    }
+    return;
+  };
+
+  std::vector<std::thread> workers(nthreads);
+  for (int thread = 0; thread < nthreads; ++thread)
+    // Launch all threads
+    workers[thread] = std::thread(thread_lambda, thread, num_sequences, num_sequences_per_thread);
+  for (int thread = 0; thread < nthreads; ++thread) {
+    // Join threads back in
+    workers[thread].join();
    // Reduce thread values to a single value
     partial_loglike += partial_loglike_mt[thread];
     ok = ok && ok_mt[thread];
   }
+
   // Transfer and add the derivatives to the values in the matrix
   AddSpecificPdfsIndirect(&derivs, index_to_pdf_, nnet_output_deriv);
   *total_loglike = partial_loglike;
   return ok;
 }
 
 BaseFloat GenericNumeratorComputation::ComputeObjf() {
+  NVTX_RANGE(__func__);
   BaseFloat partial_loglike = 0;
   const int32 num_sequences = supervision_.num_sequences;
 
@@ -275,6 +317,7 @@ BaseFloat GenericNumeratorComputation::GetTotalProb(
 void GenericNumeratorComputation::BetaLastFrame(int seq,
                                                 const Matrix<BaseFloat> &alpha,
                                                 Matrix<BaseFloat> *beta) {
+  NVTX_RANGE(__func__);
   // Sets up the beta quantity on the last frame (frame ==
   // frames_per_sequence_). Note that the betas we use here contain a
   // 1/(tot-prob) factor in order to simplify the backprop.
@@ -298,6 +341,7 @@ void GenericNumeratorComputation::BetaRemainingFrames(int seq,
     const Matrix<BaseFloat> &alpha,
     Matrix<BaseFloat> *beta,
     Matrix<BaseFloat> *derivs) {
+  NVTX_RANGE(__func__);
   const int32
       num_sequences = supervision_.num_sequences,
       num_frames = supervision_.frames_per_sequence,
@@ -340,6 +384,7 @@ void GenericNumeratorComputation::AddSpecificPdfsIndirect(
     Matrix<BaseFloat> *logprobs,
     const std::vector<MatrixIndexT> &indices,
     CuMatrixBase<BaseFloat> *output) {
+  NVTX_RANGE(__func__);
   const int32 num_sequences = supervision_.num_sequences,
       frames_per_sequence = supervision_.frames_per_sequence;

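The heart of the change: the per-sequence forward-backward loop is split across worker threads, each with its own alpha/beta matrices and its own partial log-likelihood, with a serial reduction after join. The shared derivs matrix needs no locking because each sequence writes a distinct set of columns. Below is a minimal standalone sketch of the same partitioning scheme with the Kaldi types stripped out; ProcessSequence and kNumSequences are illustrative stand-ins, not Kaldi code:

#include <algorithm>
#include <iostream>
#include <thread>
#include <vector>

double ProcessSequence(int seq) {
  // Stand-in for the per-sequence forward-backward; returns its log-like.
  return -0.5 * seq;
}

int main() {
  const int kNumSequences = 10;
  unsigned int nthreads = std::max(1u, std::thread::hardware_concurrency());
  // Ceiling division: each worker gets a contiguous chunk of sequences.
  unsigned int per_thread = (kNumSequences + nthreads - 1) / nthreads;

  // One accumulator per thread: no locks needed, since each worker only
  // writes its own slot (and, in the real code, its own alpha/beta matrices).
  std::vector<double> partial(nthreads, 0.0);

  auto worker = [&](unsigned int t) {
    int begin = t * per_thread;
    int end = std::min<int>(begin + per_thread, kNumSequences);
    for (int seq = begin; seq < end; ++seq)
      partial[t] += ProcessSequence(seq);
  };

  std::vector<std::thread> threads;
  for (unsigned int t = 0; t < nthreads; ++t)
    threads.emplace_back(worker, t);

  double total = 0.0;
  for (unsigned int t = 0; t < nthreads; ++t) {
    threads[t].join();
    total += partial[t];  // serial reduction after join
  }
  std::cout << "total log-like: " << total << std::endl;
  return 0;
}
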
src/chain/chain-generic-numerator.h

Lines changed: 21 additions & 1 deletion

@@ -25,6 +25,8 @@
 
 #include <vector>
 #include <map>
+#include <algorithm>
+#include <thread>
 
 #include "base/kaldi-common.h"
 #include "util/common-utils.h"
@@ -102,6 +104,20 @@ namespace chain {
 */
 
 
+struct GenericNumeratorComputationOptions {
+  unsigned int num_threads;
+  GenericNumeratorComputationOptions() :
+      num_threads(std::min(static_cast<unsigned int>(4),
+                           std::thread::hardware_concurrency())) { }
+  void Register(OptionsItf *opts) {
+    opts->Register("numerator-graph-threads", &num_threads, "Number of threads "
+                   "to use to parallelize the chain numerator graph computation. "
+                   "If 0, use available hardware concurrency.");
+  }
+
+};
+
+
 // This class is responsible for the forward-backward of the
 // end-to-end 'supervision' (numerator) FST. This kind of FST can
 // have self-loops.
@@ -112,7 +128,8 @@ namespace chain {
 class GenericNumeratorComputation {
  public:
   /// Initializes the object.
-  GenericNumeratorComputation(const Supervision &supervision,
+  GenericNumeratorComputation(const GenericNumeratorComputationOptions &opts,
+                              const Supervision &supervision,
                               const CuMatrixBase<BaseFloat> &nnet_output);
 
   // Does the forward-backward computation. Returns the total log-prob
@@ -198,6 +215,9 @@ class GenericNumeratorComputation {
   // an offset subtracted from the logprobs of transitions out of the first
   // state of each graph to help reduce numerical problems.
   Vector<BaseFloat> offsets_;
+
+  // Configuration options
+  const GenericNumeratorComputationOptions &opts_;
 };
 
 }  // namespace chain

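The default thread count is min(4, std::thread::hardware_concurrency()), and passing --numerator-graph-threads=0 defers to hardware concurrency inside ForwardBackward(). A minimal sketch of wiring the struct to a command line, assuming a Kaldi build; the program and usage string are made up for illustration:

#include "base/kaldi-common.h"
#include "util/parse-options.h"
#include "chain/chain-generic-numerator.h"

int main(int argc, char *argv[]) {
  // Hypothetical demo binary, not part of this commit.
  kaldi::ParseOptions po("Illustrative options demo.\nUsage: demo [options]");
  kaldi::chain::GenericNumeratorComputationOptions numerator_opts;
  numerator_opts.Register(&po);  // exposes --numerator-graph-threads
  po.Read(argc, argv);
  KALDI_LOG << "numerator threads: " << numerator_opts.num_threads;
  return 0;
}
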
src/chain/chain-training.cc

Lines changed: 4 additions & 1 deletion

@@ -94,6 +94,7 @@ void ComputeChainObjfAndDerivE2e(const ChainTrainingOptions &opts,
                                  BaseFloat *weight,
                                  CuMatrixBase<BaseFloat> *nnet_output_deriv,
                                  CuMatrix<BaseFloat> *xent_output_deriv) {
+  NVTX_RANGE(__func__);
   BaseFloat num_logprob_weighted, den_logprob_weighted;
   bool denominator_ok = true;
   bool numerator_ok = true;
@@ -136,7 +137,8 @@ void ComputeChainObjfAndDerivE2e(const ChainTrainingOptions &opts,
 
 
   {
-    GenericNumeratorComputation numerator(supervision, nnet_output);
+    GenericNumeratorComputation numerator(opts.numerator_opts,
+                                          supervision, nnet_output);
     // note: supervision.weight is included as a factor in the derivative from
     // the numerator object, as well as the returned logprob.
     if (xent_output_deriv) {
@@ -211,6 +213,7 @@ void ComputeChainObjfAndDeriv(const ChainTrainingOptions &opts,
                               BaseFloat *weight,
                               CuMatrixBase<BaseFloat> *nnet_output_deriv,
                               CuMatrix<BaseFloat> *xent_output_deriv) {
+  NVTX_RANGE(__func__);
   if (!supervision.e2e_fsts.empty()) {
     ComputeChainObjfAndDerivE2e(opts, den_graph, supervision,
                                 nnet_output, objf, l2_term,

src/chain/chain-training.h

Lines changed: 6 additions & 0 deletions

@@ -34,6 +34,7 @@
 #include "hmm/transition-model.h"
 #include "chain/chain-den-graph.h"
 #include "chain/chain-supervision.h"
+#include "chain/chain-generic-numerator.h"
 
 namespace kaldi {
 namespace chain {
@@ -93,7 +94,12 @@ struct ChainTrainingOptions {
                    "nonzero, the network is expected to have an output "
                    "named 'output-xent', which should have a softmax as "
                    "its final nonlinearity.");
+
+    numerator_opts.Register(opts);
   }
+
+  // Config for numerator graph object
+  GenericNumeratorComputationOptions numerator_opts;
 };
 

src/chainbin/nnet3-chain-train.cc

Lines changed: 3 additions & 1 deletion

@@ -22,7 +22,6 @@
 #include "nnet3/nnet-chain-training.h"
 #include "cudamatrix/cu-allocator.h"
 
-
 int main(int argc, char *argv[]) {
   try {
     using namespace kaldi;
@@ -53,6 +52,9 @@ int main(int argc, char *argv[]) {
         "yes|no|optional|wait, only has effect if compiled with CUDA");
 
     opts.Register(&po);
+#if HAVE_CUDA==1
+    CuDevice::RegisterDeviceOptions(&po);
+#endif
     RegisterCuAllocatorOptions(&po);
 
     po.Read(argc, argv);

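Because ChainTrainingOptions::Register() now forwards to numerator_opts.Register() (see chain-training.h above), the new flag surfaces on the chain training binaries without any per-binary plumbing. An illustrative invocation; the paths, filenames, and values here are hypothetical:

nnet3-chain-train --use-gpu=yes --numerator-graph-threads=8 \
  exp/chain/0.raw exp/chain/den.fst ark:egs.ark exp/chain/1.raw
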
src/cudamatrix/cu-common.cc

Lines changed: 21 additions & 1 deletion

@@ -31,10 +31,30 @@
 #include "cudamatrix/cu-common.h"
 #include "cudamatrix/cu-matrixdim.h"
 
-
 namespace kaldi {
 
 #if HAVE_CUDA == 1
+
+#ifdef USE_NVTX
+NvtxTracer::NvtxTracer(const char* name) {
+  const uint32_t colors[] = { 0xff00ff00, 0xff0000ff, 0xffffff00, 0xffff00ff, 0xff00ffff, 0xffff0000, 0xffffffff };
+  const int num_colors = sizeof(colors)/sizeof(uint32_t);
+  int color_id = ((int)name[0])%num_colors;
+  nvtxEventAttributes_t eventAttrib = {0};
+  eventAttrib.version = NVTX_VERSION;
+  eventAttrib.size = NVTX_EVENT_ATTRIB_STRUCT_SIZE;
+  eventAttrib.colorType = NVTX_COLOR_ARGB;
+  eventAttrib.color = colors[color_id];
+  eventAttrib.messageType = NVTX_MESSAGE_TYPE_ASCII;
+  eventAttrib.message.ascii = name;
+  nvtxRangePushEx(&eventAttrib);
+  // nvtxRangePushA(name);
+}
+NvtxTracer::~NvtxTracer() {
+  nvtxRangePop();
+}
+#endif
+
 cublasOperation_t KaldiTransToCuTrans(MatrixTransposeType kaldi_trans) {
   cublasOperation_t cublas_trans;
 

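The header-side declarations that pair with this definition are not part of the hunks shown here. A plausible shape for that glue, sketched under the assumption that cu-common.h guards it on USE_NVTX the same way; this is not a verbatim copy of Kaldi's header:

// Assumed companion declarations for cu-common.h (illustrative): an RAII
// class that pushes an NVTX range in its constructor and pops it in its
// destructor, plus a macro that compiles to nothing when NVTX is disabled.
#ifdef USE_NVTX
#include <nvToolsExt.h>
namespace kaldi {
class NvtxTracer {
 public:
  explicit NvtxTracer(const char *name);  // calls nvtxRangePushEx(...)
  ~NvtxTracer();                          // calls nvtxRangePop()
};
}  // namespace kaldi
#define NVTX_RANGE(name) ::kaldi::NvtxTracer _nvtx_range_tracer(name)
#else
#define NVTX_RANGE(name)
#endif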