 #include "cudamatrix/cu-common.h"
 namespace kaldi {
 
-// computes feats^2. This works in place and out of place.
+// computes pointwise square of each matrix
 __global__ void square_batched_matrix_kernel(
     int32_t chunk_frames, int32_t num_cols, const float *feats, int32_t ldf,
     int32_t stridef, float *feats_sq, int32_t lds, int32_t strides,
     const LaneDesc *lanes, int32_t num_lanes) {
   int32_t lane = blockIdx.z;
-  int32_t num_chunk_frames = lanes[lane].num_chunk_frames;
 
   feats = feats + lane * stridef;
   feats_sq = feats_sq + lane * strides;
 
-  for (int i = blockIdx.y * blockDim.y + threadIdx.y; i < num_chunk_frames;
+  for (int i = blockIdx.y * blockDim.y + threadIdx.y; i < chunk_frames;
        i += blockDim.y * gridDim.y) {
     for (int j = blockIdx.x * blockDim.x + threadIdx.x; j < num_cols;
          j += blockDim.x * gridDim.x) {
@@ -56,6 +55,55 @@ void square_batched_matrix(int32_t chunk_frames, int32_t num_cols,
   CU_SAFE_CALL(cudaGetLastError());
 }
 
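Each lane processed by this kernel is a chunk_frames x num_cols sub-matrix: ldf and lds are row leading dimensions, while stridef and strides are per-lane offsets into the batched buffers. A minimal CPU sketch of the same indexing, reusing the parameter names from the kernel signature above (square_batched_matrix_ref is an illustrative name, not part of the patch's API):

    // CPU reference for the batched elementwise square. Assumes row-major
    // storage with leading dimensions ldf/lds and per-lane strides.
    void square_batched_matrix_ref(int32_t chunk_frames, int32_t num_cols,
                                   const float *feats, int32_t ldf,
                                   int32_t stridef, float *feats_sq,
                                   int32_t lds, int32_t strides,
                                   int32_t num_lanes) {
      for (int32_t lane = 0; lane < num_lanes; lane++) {
        const float *in = feats + lane * stridef;
        float *out = feats_sq + lane * strides;
        for (int32_t i = 0; i < chunk_frames; i++)
          for (int32_t j = 0; j < num_cols; j++)
            out[i * lds + j] = in[i * ldf + j] * in[i * ldf + j];
      }
    }
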
+// After computing posteriors, some rows are invalid because they were
+// computed from rows with undefined data. This kernel zeros those rows out
+// so that they will not contribute to stats.
+__global__ void zero_invalid_posteriors_kernel(
+    int32_t chunk_size, int32_t num_gauss, float *posteriors, int32_t ldp,
+    int32_t stridep, int32_t right, const LaneDesc *lanes, int32_t num_lanes) {
+  int32_t lane = blockIdx.z;
+
+  LaneDesc desc = lanes[lane];
+  int32_t num_chunk_frames = desc.num_chunk_frames;
+  int32_t current_frame = desc.current_frame;
+  bool last = desc.last;
+
+  // the number of valid rows, in global frame indices
+  int32_t num_computed_rows = current_frame + num_chunk_frames;
+
+  // if this is not the last chunk, remove the right context
+  if (!last) {
+    num_computed_rows -= right;
+  }
+
+  // offset by lane
+  posteriors = posteriors + lane * stridep;
+
+  for (int r = blockIdx.y * blockDim.y + threadIdx.y; r < chunk_size;
+       r += blockDim.y * gridDim.y) {
+    int global_row = current_frame + r - right;
+    if (global_row < 0 || global_row >= num_computed_rows) {
+      // zero this row out
+      for (int c = blockIdx.x * blockDim.x + threadIdx.x; c < num_gauss;
+           c += blockDim.x * gridDim.x) {
+        posteriors[r * ldp + c] = 0.0f;
+      }
+    }
+  }
+}
+
+void zero_invalid_posteriors(int32_t num_chunk_frames, int32_t num_gauss,
+                             float *posteriors, int32_t ldp, int32_t stridep,
+                             int32_t right, const LaneDesc *lanes,
+                             int32_t num_lanes) {
+  dim3 threads(32, 32);
+  dim3 blocks((num_gauss + 31) / 32, (num_chunk_frames + 31) / 32, num_lanes);
+
+  zero_invalid_posteriors_kernel<<<blocks, threads>>>(
+      num_chunk_frames, num_gauss, posteriors, ldp, stridep, right, lanes,
+      num_lanes);
+}
+
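The validity rule enforced by this kernel can be stated in scalar form: chunk row r corresponds to global frame current_frame + r - right, and only global frames in [0, num_computed_rows) carry valid data. A sketch of the predicate, assuming the same LaneDesc fields used above (posterior_row_is_valid is an illustrative name):

    // Returns true if chunk row r holds a valid posterior row.
    bool posterior_row_is_valid(int32_t r, int32_t right, int32_t current_frame,
                                int32_t num_chunk_frames, bool last) {
      int32_t num_computed_rows = current_frame + num_chunk_frames;
      if (!last) num_computed_rows -= right;  // right context not yet valid
      int32_t global_row = current_frame + r - right;
      return global_row >= 0 && global_row < num_computed_rows;
    }
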
 // Meant to be called with blockDim = 32x32
 // takes features in feat and writes them into sfeats while applying
 // the splicing algorithm for the left and right context.
@@ -67,39 +115,48 @@ __global__ void splice_features_batched_kernel(
     float *__restrict__ feats_out, int32_t ldo, int32_t strideo,
     const LaneDesc *lanes, int32_t num_lanes) {
   int32_t lane = blockIdx.y;
-  int32_t frame = blockIdx.x;
+  // output frame index
+  int32_t oframe = blockIdx.x;
   int32_t tid = threadIdx.x;
 
   LaneDesc desc = lanes[lane];
   int32_t num_chunk_frames = desc.num_chunk_frames;
   int32_t channel = desc.channel;
-  int32_t start_frame = desc.current_frame;
+  int32_t current_frame = desc.current_frame;
+  bool last = desc.last;
 
-  bool valid_frame = true;
-  // check that we have valid input
-  if (frame >= num_chunk_frames) {
-    valid_frame = false;
-  }
+  // offset by lane
+  feats_in = feats_in + lane * stridei;
+  feats_out = feats_out + lane * strideo;
 
-  // for first chunk we process less frames
-  if (start_frame == 0 && frame >= num_chunk_frames - right) {
-    valid_frame = false;
-  }
+  // offset by channel
+  feats_stash = feats_stash + channel * stridest;
 
-  // the stash size
+  // offset feature output to process oframe
+  feats_out = feats_out + ldo * oframe;
+
+  // the size of the stash
   int32_t ssize = left + right;
+  // the size of the window
   int32_t size = ssize + 1;
 
-  // offset by lane
-  feats_in = feats_in + lane * stridei;
-  feats_out = feats_out + lane * strideo;
-  feats_stash = feats_stash + channel * stridest;
+  // number of valid frames for reading
+  int32_t num_valid_frames = current_frame + num_chunk_frames;
 
-  // offset feature output to process frame
-  feats_out = feats_out + ldo * frame;
+  // number of valid frames for writing
+  int32_t num_computed_frames = num_valid_frames;
 
-  if (!valid_frame) {
-    // this frames output is not valid, zero it here
+  // if this is not the last chunk, remove the right context
+  if (!last) {
+    num_computed_frames -= right;
+  }
+
+  // subtract right context from logical frame to delay output
+  int32_t local_frame = oframe - right;
+  int32_t global_frame = current_frame + local_frame;
+
+  // frames outside the valid range are zeroed
+  if (global_frame < 0 || global_frame >= num_computed_frames) {
     for (int i = 0; i < size; i++) {
       for (int c = tid; c < feat_dim; c += blockDim.x) {
         feats_out[i * feat_dim + c] = 0.0f;
@@ -108,44 +165,40 @@ __global__ void splice_features_batched_kernel(
     return;
   }
 
-  // for each splice of input
-  for (int i = 0; i < size; i++) {
-    const float *feats_src = feats_in;
-    int32_t ld = ldi;
-
-    // shift input row by left context
-    int r = frame + i - left;
+  for (int i = -left; i <= right; i++) {
+    int32_t g_in = global_frame + i;  // global frame index
+    int32_t l_in = local_frame + i;   // local frame index
 
-    // clamp input row if necessary
-    if (start_frame + r < 0) {
-      r = 0;
-    }
+    // if the global row is below zero, clamp the local row to zero
+    if (g_in < 0) l_in = 0;
 
-    // if we have a right context shift input row by that too
-    if (start_frame > 0) {
-      r = r - right;
+    // if the global row is beyond the number of valid frames
+    if (g_in >= num_valid_frames) {
+      // this should only happen on the last chunk
+      assert(last);
+      // clamp the input row
+      l_in = num_chunk_frames - 1;
     }
 
-    if (r > num_chunk_frames - 1) {
-      // This should only happen on the last chunk
-      assert(desc.last == true);
-      r = num_chunk_frames - 1;
-    }
+    // set default input location
+    const float *feats = feats_in;
+    int32_t ld = ldi;
 
-    if (r < 0) {
-      // feats are located in stash from previous chunk
-      feats_src = feats_stash;
+    // if l_in < 0 then feats come from the stash
+    if (l_in < 0) {
+      // input is from stash
+      feats = feats_stash;
       ld = ldst;
-      r = r + ssize;
+      l_in += ssize;  // offset by stash size
     }
 
     // for each column of input in parallel
     for (int c = tid; c < feat_dim; c += blockDim.x) {
       // read feature from input row offset by column
-      float val = feats_src[r * ld + c];
+      float val = feats[l_in * ld + c];
 
       // write feature to output offset by splice index and column
-      feats_out[i * feat_dim + c] = val;
+      feats_out[(i + left) * feat_dim + c] = val;
     }
   }
 }
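For each output frame the rewritten loop gathers a window of size left + right + 1, clamping reads at the start and end of the utterance and falling back to the per-channel stash for rows that precede the current chunk. A scalar sketch of the source-row selection, reusing the kernel's variable names (SpliceSrc and splice_source are illustrative names):

    // Picks the input row and buffer for splice offset i in [-left, right],
    // mirroring the clamping logic in splice_features_batched_kernel.
    struct SpliceSrc {
      bool from_stash;  // read from feats_stash instead of feats_in
      int32_t row;      // row index within the chosen buffer
    };

    SpliceSrc splice_source(int32_t i, int32_t local_frame,
                            int32_t global_frame, int32_t num_valid_frames,
                            int32_t num_chunk_frames, int32_t left,
                            int32_t right) {
      int32_t g_in = global_frame + i;
      int32_t l_in = local_frame + i;
      if (g_in < 0) l_in = 0;               // clamp at utterance start
      if (g_in >= num_valid_frames)
        l_in = num_chunk_frames - 1;        // clamp at utterance end
      if (l_in < 0)
        return {true, l_in + left + right}; // row lives in the stash
      return {false, l_in};
    }
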
@@ -159,6 +212,7 @@ void splice_features_batched(int32_t num_chunk_frames, int32_t feat_dim,
                              const LaneDesc *lanes, int32_t num_lanes) {
   int threads = (feat_dim + 31) / 32 * 32;  // round up to the nearest warp size
   if (threads > 1024) threads = 1024;  // Max block size is 1024 threads
+
   dim3 blocks(num_chunk_frames, num_lanes);
 
   splice_features_batched_kernel<<<blocks, threads>>>(
@@ -302,12 +356,15 @@ __global__ void batched_update_linear_and_quadratic_terms_kernel(
   linear = linear + lane * stridel;
   quadratic = quadratic + lane * strideq;
 
-  // This is always zero. not 100% certain as why we don't need
-  // to account for earlier chunk. maybe Dan knows.
+  // This is always zero because the linear and quadratic terms are not
+  // carried forward, so there is no old prior scale to remove. The code
+  // below is kept so that it logically matches the CPU code, in case
+  // someone looks at this in the future.
   float old_num_frames = 0;
   // float old_num_frames = desc.current_frame;
   float new_num_frames = desc.current_frame + desc.num_chunk_frames;
 
+  // in the CPU code the frame counts are scaled by the posterior scale
   new_num_frames *= posterior_scale;
   old_num_frames *= posterior_scale;
 
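The effect of the scaling is easiest to see with concrete, purely hypothetical numbers. With current_frame = 50, num_chunk_frames = 10, and posterior_scale = 0.1:

    float old_num_frames = 0;        // stats are not carried forward
    float new_num_frames = 50 + 10;  // current_frame + num_chunk_frames
    new_num_frames *= 0.1f;          // 6.0f, matching the CPU-side scaling
    old_num_frames *= 0.1f;          // stays 0.0f, so nothing is removed
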
@@ -458,9 +515,6 @@ __global__ void batched_sum_posteriors_kernel(
     int32_t stridep, float *gamma, int32_t strideg, float post_scale,
     const LaneDesc *lanes, int32_t num_lanes) {
   int32_t lane = blockIdx.y;
-  LaneDesc desc = lanes[lane];
-
-  int32_t num_rows = desc.num_chunk_frames;
 
   // offset input and output by lane
   posteriors = posteriors + lane * stridep;
@@ -471,7 +525,7 @@ __global__ void batched_sum_posteriors_kernel(
        col += blockDim.x * gridDim.x) {
     // compute sum across rows for this column
     float sum = 0.0f;
-    for (int row = 0; row < num_rows; row++) {
+    for (int row = 0; row < chunk_size; row++) {
       sum += posteriors[row * ldp + col];
     }
 
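Summing over the full chunk_size rather than the per-lane num_chunk_frames pairs with the new zero_invalid_posteriors kernel: once out-of-range rows are zeroed they contribute nothing to the sum, so every lane can use the same loop bound. A CPU sketch of the reduction, reusing the parameter names above (whether gamma is accumulated or overwritten lies outside this hunk; the sketch assumes accumulation, and sum_posteriors_ref is an illustrative name):

    void sum_posteriors_ref(int32_t chunk_size, int32_t num_gauss,
                            const float *posteriors, int32_t ldp,
                            int32_t stridep, float *gamma, int32_t strideg,
                            float post_scale, int32_t num_lanes) {
      for (int32_t lane = 0; lane < num_lanes; lane++) {
        const float *p = posteriors + lane * stridep;
        float *g = gamma + lane * strideg;
        for (int32_t col = 0; col < num_gauss; col++) {
          float sum = 0.0f;
          for (int32_t row = 0; row < chunk_size; row++)
            sum += p[row * ldp + col];
          g[col] += sum * post_scale;  // assumed accumulation into gamma
        }
      }
    }
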
@@ -509,7 +563,7 @@ __global__ void initialize_channels_kernel(int32_t num_gauss, int32_t feat_dim,
 
   // initialize stashes to zero
   for (int i = threadIdx.y * blockDim.x + threadIdx.x; i < num_gauss;
-       i += blockDim.x * gridDim.x) {
+       i += blockDim.y * blockDim.x) {
     gamma[i] = 0.0f;
   }
 
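The corrected stride matters because the loop index flattens a 2D thread block: threadIdx.y * blockDim.x + threadIdx.x ranges over blockDim.y * blockDim.x values, so that product is the stride that lets one block partition [0, num_gauss) without gaps. The old stride, blockDim.x * gridDim.x, belongs to a 1D grid-stride loop; with this 2D block it can skip elements (whenever gridDim.x exceeds blockDim.y) or redundantly rewrite them. A minimal standalone sketch of the block-stride idiom:

    // One thread block zeros an array using a flattened 2D thread index.
    __global__ void zero_array_block_stride(float *data, int32_t n) {
      int32_t tid = threadIdx.y * blockDim.x + threadIdx.x;  // flattened id
      int32_t block_size = blockDim.y * blockDim.x;          // threads per block
      for (int32_t i = tid; i < n; i += block_size) data[i] = 0.0f;
    }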