@@ -90,258 +90,6 @@ void hl_max_sequence_backward(real* outputGrad,
90
90
CHECK_SYNC (" hl_max_sequence_backward failed" );
91
91
}
92
92
93
- template <bool padding>
94
- __global__ void KeContextProjectionForward (real* input,
95
- const int * sequence,
96
- real* weightData,
97
- real* output,
98
- int inputDim,
99
- int contextLength,
100
- int contextStart,
101
- int beginPad) {
102
- int idx = threadIdx .x ;
103
- int blockSize = blockDim .x ;
104
- int sequenceId = blockIdx .x ;
105
- int seqStart = sequence[sequenceId];
106
- int seqEnd = sequence[sequenceId+1 ];
107
- real value = 0 ;
108
-
109
- int instances = seqEnd - seqStart + contextLength - 1 ;
110
- output += seqStart * inputDim * contextLength;
111
- input += seqStart * inputDim;
112
- for (int k = 0 ; k <= inputDim / blockSize; k++) {
113
- if (idx < inputDim) {
114
- for (int i = 0 ; i < instances; i++) {
115
- // i + contextStart;
116
- if ((i + contextStart) < 0 ) {
117
- if (padding) {
118
- value = weightData[i * inputDim + idx];
119
- } else {
120
- continue ;
121
- }
122
- } else if ((i + contextStart) >= (seqEnd - seqStart)) {
123
- if (padding) {
124
- value =
125
- weightData[(beginPad + i + contextStart - (seqEnd - seqStart)) *
126
- inputDim + idx];
127
- } else {
128
- continue ;
129
- }
130
- } else {
131
- value = input[(i + contextStart) * inputDim + idx];
132
- }
133
-
134
- int outx = (i - contextLength) < 0 ? i : (contextLength - 1 );
135
- int outy = (i - contextLength) < 0 ? 0 : (i - (contextLength - 1 ));
136
- real* output_r =
137
- output + outy * inputDim * contextLength + outx * inputDim;
138
- for (int j = outy; j < seqEnd - seqStart; j++) {
139
- output_r[idx] += value;
140
- if (j - outy == outx) break ;
141
- output_r += (contextLength - 1 ) * inputDim;
142
- }
143
- }
144
- }
145
- idx += blockSize;
146
- }
147
- }
148
-
149
- void hl_context_projection_forward (real* input,
150
- const int * sequence,
151
- real* weightData,
152
- real* output,
153
- int numSequences,
154
- int inputDim,
155
- int contextLength,
156
- int contextStart,
157
- int beginPad,
158
- bool isPadding) {
159
- CHECK_NOTNULL (input);
160
- CHECK_NOTNULL (sequence);
161
- CHECK_NOTNULL (output);
162
- CHECK (!isPadding || weightData);
163
-
164
- int blockSize = 128 ;
165
- int blocksX = numSequences;
166
- int blocksY = 1 ;
167
- dim3 threads (blockSize, 1 );
168
- dim3 grid (blocksX, blocksY);
169
-
170
- if (isPadding) {
171
- KeContextProjectionForward<true ><<< grid, threads, 0 , STREAM_DEFAULT >>>
172
- (input, sequence, weightData, output, inputDim,
173
- contextLength, contextStart, beginPad);
174
- } else {
175
- KeContextProjectionForward<false ><<< grid, threads, 0 , STREAM_DEFAULT >>>
176
- (input, sequence, weightData, output, inputDim,
177
- contextLength, contextStart, beginPad);
178
- }
179
- CHECK_SYNC (" hl_context_projection_forward failed" );
180
- }
181
-
182
- __global__ void KeContextProjectionBackwardData (real* outputGrad,
183
- const int * sequence,
184
- real* inputGrad,
185
- int inputDim,
186
- int contextLength,
187
- int contextStart) {
188
- int idx = threadIdx .x ;
189
- int blockSize = blockDim .x ;
190
- int sequenceId = blockIdx .x ;
191
- int seqStart = sequence[sequenceId];
192
- int seqEnd = sequence[sequenceId+1 ];
193
- real value = 0 ;
194
-
195
- int instances = seqEnd - seqStart + contextLength - 1 ;
196
- outputGrad += seqStart * inputDim * contextLength;
197
- inputGrad += seqStart * inputDim;
198
- for (int k = 0 ; k <= inputDim / blockSize; k++) {
199
- if (idx < inputDim) {
200
- for (int i = 0 ; i < instances; i++) {
201
- if ((i + contextStart) < 0 ) {
202
- continue ;
203
- } else if ((i + contextStart) >= (seqEnd - seqStart)) {
204
- continue ;
205
- } else {
206
- // value = 0;
207
- value = inputGrad[(i + contextStart) * inputDim + idx];
208
- }
209
-
210
- int outx = (i - contextLength) < 0 ? i : (contextLength - 1 );
211
- int outy = (i - contextLength) < 0 ? 0 : (i - (contextLength - 1 ));
212
- real* output_r =
213
- outputGrad + outy * inputDim * contextLength + outx * inputDim;
214
- for (int j = outy; j < seqEnd - seqStart; j++) {
215
- value += output_r[idx];
216
- if (j - outy == outx) break ;
217
- output_r += (contextLength - 1 ) * inputDim;
218
- }
219
- inputGrad[(i + contextStart) * inputDim + idx] = value;
220
- }
221
- }
222
- idx += blockSize;
223
- }
224
- }
225
-
226
- void hl_context_projection_backward_data (real* outputGrad,
227
- const int * sequence,
228
- real* inputGrad,
229
- int numSequences,
230
- int inputDim,
231
- int contextLength,
232
- int contextStart) {
233
- CHECK_NOTNULL (outputGrad);
234
- CHECK_NOTNULL (sequence);
235
- CHECK_NOTNULL (inputGrad);
236
-
237
- int blockSize = 128 ;
238
- int blocksX = numSequences;
239
- int blocksY = 1 ;
240
- dim3 threads (blockSize, 1 );
241
- dim3 grid (blocksX, blocksY);
242
- KeContextProjectionBackwardData<<< grid, threads, 0 , STREAM_DEFAULT >>>
243
- (outputGrad, sequence, inputGrad, inputDim, contextLength, contextStart);
244
- CHECK_SYNC (" hl_context_projection_backward_data failed" );
245
- }
246
-
247
- template <int THREADS_X, int THREADS_Y>
248
- __global__ void KeContextProjectionBackwardWeight (real* outputGrad,
249
- const int * sequence,
250
- real* weightGrad,
251
- int numSequences,
252
- int weightDim,
253
- int contextLength,
254
- int contextStart,
255
- int beginPad) {
256
- __shared__ real sum_s[THREADS_Y][THREADS_X];
257
- int padOfBlock = (weightDim + THREADS_X - 1 ) / THREADS_X;
258
- const int idx = threadIdx .x ;
259
- const int idy = threadIdx .y ;
260
- int padId = blockIdx .x / padOfBlock;
261
- int weightIdx = idx + THREADS_X * (blockIdx .x % padOfBlock);
262
- int instanceId;
263
- real value = 0 ;
264
- real* output_r;
265
-
266
- sum_s[idy][idx] = 0 .0f ;
267
- if (weightIdx < weightDim) {
268
- for (int seqId = idy; seqId < numSequences; seqId += THREADS_Y) {
269
- int seqStart = sequence[seqId];
270
- int seqEnd = sequence[seqId+1 ];
271
- output_r = outputGrad + seqStart * weightDim * contextLength;
272
-
273
- if (contextStart < 0 ) {
274
- if (padId + contextStart < 0 ) {
275
- instanceId = padId;
276
- } else {
277
- // beginPad > 0;
278
- instanceId = (padId - beginPad) + (seqEnd - seqStart) - contextStart;
279
- }
280
- } else {
281
- if (padId + (seqEnd - seqStart) < contextStart) {
282
- continue ;
283
- } else {
284
- // beginPad == 0;
285
- instanceId = padId + (seqEnd - seqStart) - contextStart;
286
- }
287
- }
288
-
289
- int outx = (instanceId - contextLength) < 0 ?
290
- instanceId : (contextLength - 1 );
291
- int outy = (instanceId - contextLength) < 0 ?
292
- 0 : (instanceId - (contextLength - 1 ));
293
- output_r += outy * weightDim * contextLength + outx * weightDim;
294
- for (int j = outy; j < seqEnd - seqStart; j++) {
295
- value += output_r[weightIdx];
296
- if (j - outy == outx) break ;
297
- output_r += (contextLength - 1 ) * weightDim;
298
- }
299
- }
300
- sum_s[idy][idx] = value;
301
- }
302
- __syncthreads ();
303
-
304
- for (int stride = THREADS_Y/2 ; stride > 0 ; stride = stride/2 ) {
305
- if (idy < stride) {
306
- sum_s[idy][idx] += sum_s[idy + stride][idx];
307
- }
308
- __syncthreads ();
309
- }
310
- __syncthreads ();
311
-
312
- if (weightIdx < weightDim) {
313
- if (idy == 0 ) {
314
- weightGrad[padId * weightDim + weightIdx] += sum_s[0 ][idx];
315
- }
316
- }
317
- }
318
-
319
- void hl_context_projection_backward_weight (real* outputGrad,
320
- const int * sequence,
321
- real* weightGrad,
322
- int numSequences,
323
- int weightDim,
324
- int totalPad,
325
- int contextLength,
326
- int contextStart,
327
- int beginPad) {
328
- CHECK_NOTNULL (outputGrad);
329
- CHECK_NOTNULL (sequence);
330
- CHECK_NOTNULL (weightGrad);
331
-
332
- int threadsX = 32 ;
333
- int threadsY = 32 ;
334
- int blocksX = totalPad * ((weightDim + threadsX - 1 ) / threadsX);
335
- dim3 threads (threadsX, threadsY);
336
- dim3 grid (blocksX, 1 );
337
-
338
- KeContextProjectionBackwardWeight<32 , 32 >
339
- <<< grid, threads, 0 , STREAM_DEFAULT >>>
340
- (outputGrad, sequence, weightGrad, numSequences, weightDim,
341
- contextLength, contextStart, beginPad);
342
- CHECK_SYNC (" hl_context_projection_backward_weight failed" );
343
- }
344
-
345
93
template <int blockDimX, int blockDimY, int gridDimX, bool AddRow>
346
94
__global__ void KeMatrixAddRows (real* output,
347
95
real* table,
0 commit comments