Skip to content

Commit f1c9ae0

Browse files
btiplitz authored and danpovey committed
[src] Add nnet2 Chunking on GPU (#3761)
1 parent ba92f60 commit f1c9ae0

File tree

4 files changed

+17
-17
lines changed

4 files changed

+17
-17
lines changed

src/nnet2/nnet-compute-test.cc

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -44,7 +44,6 @@ void UnitTestNnetCompute() {
4444
return;
4545
CuMatrix<BaseFloat> output1(num_output_rows, output_dim);
4646
NnetComputation(*nnet, input, pad_input, &output1);
47-
4847
CuMatrix<BaseFloat> output2(output1.NumRows(), output1.NumCols());
4948
int32 cur_input_pos = 0, cur_output_pos = 0;
5049

@@ -98,11 +97,12 @@ void UnitTestNnetComputeChunked() {
9897

9998
int32 num_output_rows = num_feats;
10099
CuMatrix<BaseFloat> cu_output1(num_output_rows, output_dim);
101-
Matrix<BaseFloat> output2(num_output_rows, output_dim);
100+
CuMatrix<BaseFloat> cu_output2(num_output_rows, output_dim);
102101
NnetComputation(*nnet, input, pad_input, &cu_output1);
103-
NnetComputationChunked(*nnet, Matrix<BaseFloat>(input), chunk_size,
104-
&output2);
102+
NnetComputationChunked(*nnet, CuMatrix<BaseFloat>(input), chunk_size,
103+
&cu_output2);
105104
Matrix<BaseFloat> output1(cu_output1);
105+
Matrix<BaseFloat> output2(cu_output2);
106106
AssertEqual(output1, output2);
107107
for (int32 i = 0; i < output1.NumRows(); i++) {
108108
// just double-check that the frames near the end are right, in case

src/nnet2/nnet-compute.cc

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -167,15 +167,15 @@ void NnetComputation(const Nnet &nnet,
167167
}
168168

169169
void NnetComputationChunked(const Nnet &nnet,
170-
const Matrix<BaseFloat> &input, // features
170+
const CuMatrixBase<BaseFloat> &input, // features
171171
int32 chunk_size,
172-
Matrix<BaseFloat> *output) {
172+
CuMatrixBase<BaseFloat> *output) {
173173
int32 num_rows,
174174
num_chunks = ceil((BaseFloat)input.NumRows() / chunk_size),
175175
dim = input.NumCols(),
176176
left_context = nnet.LeftContext(),
177177
right_context = nnet.RightContext();
178-
Matrix<BaseFloat> full_input;
178+
CuMatrix<BaseFloat> full_input;
179179
num_rows = left_context + input.NumRows() + right_context;
180180
full_input.Resize(num_rows, dim);
181181
full_input.Range(left_context, input.NumRows(),
@@ -190,15 +190,15 @@ void NnetComputationChunked(const Nnet &nnet,
190190
int32 index = i * chunk_size,
191191
offset = std::min(num_rows - chunk_size * i,
192192
left_context + chunk_size + right_context);
193-
SubMatrix<BaseFloat> chunk_input(full_input, index, offset, 0, dim);
193+
CuSubMatrix<BaseFloat> chunk_input(full_input, index, offset, 0, dim);
194194
CuMatrix<BaseFloat> cu_chunk_input(chunk_input);
195195

196196
// Note: we have already accounted for input padding, so we pass
197197
// pad_input==false to the NnetComputer.
198198
NnetComputer nnet_computer(nnet, cu_chunk_input, false, NULL);
199199
nnet_computer.Propagate();
200200
CuMatrix<BaseFloat> cu_chunk_output(nnet_computer.GetOutput());
201-
SubMatrix<BaseFloat> chunk_out(*output, i * chunk_size,
201+
CuSubMatrix<BaseFloat> chunk_out(*output, i * chunk_size,
202202
cu_chunk_output.NumRows(), 0,
203203
cu_chunk_output.NumCols());
204204
chunk_out.CopyFromMat(cu_chunk_output);

src/nnet2/nnet-compute.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -56,9 +56,9 @@ void NnetComputation(const Nnet &nnet,
5656
input.NumRows().
5757
*/
5858
void NnetComputationChunked(const Nnet &nnet,
59-
const Matrix<BaseFloat> &input, // features
59+
const CuMatrixBase<BaseFloat> &input, // features
6060
int32 chunk_size,
61-
Matrix<BaseFloat> *output); // posteriors.
61+
CuMatrixBase<BaseFloat> *output); // posteriors.
6262

6363
/** Does the neural net computation and backprop, given input and labels.
6464
Note: if pad_input==true the number of rows of input should be the

src/nnet2bin/nnet-am-compute.cc

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -115,15 +115,15 @@ int main(int argc, char *argv[]) {
115115
}
116116

117117
Matrix<BaseFloat> output(output_frames, output_dim);
118-
if (chunk_size > 0 && chunk_size < feats.NumRows()) {
119-
NnetComputationChunked(nnet, feats, chunk_size, &output);
118+
CuMatrix<BaseFloat> cu_feats(feats);
119+
CuMatrix<BaseFloat> cu_output(output);
120+
if (chunk_size > 0 && chunk_size < feats.NumRows()) {
121+
NnetComputationChunked(nnet, cu_feats, chunk_size, &cu_output);
120122
} else {
121-
CuMatrix<BaseFloat> cu_feats(feats);
122-
CuMatrix<BaseFloat> cu_output(output);
123123
NnetComputation(nnet, cu_feats, pad_input, &cu_output);
124-
output.CopyFromMat(cu_output);
125124
}
126-
125+
cu_output.Swap(&output);
126+
127127
if (divide_by_priors) {
128128
output.MulColsVec(inv_priors); // scales each column by the corresponding element
129129
// of inv_priors.

0 commit comments

Comments
 (0)