Skip to content

Commit be3e276

Browse files
authored
Merge pull request #1009 from tianbingsz/paddle_func_mat
add paddle functions for Matrix ContextProjection APIs
2 parents 54a2b1f + ec6b13d commit be3e276

17 files changed

+1201
-800
lines changed

paddle/cuda/include/hl_sequence.h

Lines changed: 0 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -48,78 +48,6 @@ extern void hl_max_sequence_forward(real* input,
4848
extern void hl_max_sequence_backward(
4949
real* outputGrad, int* index, real* inputGrad, int numSequences, int dim);
5050

51-
/**
52-
* @brief Context projection forward.
53-
*
54-
* @param[in] input input sequence.
55-
* @param[in] sequence sequence index.
56-
* @param[in] weightData padding data.
57-
* @param[out] output output sequence.
58-
* @param[in] numSequences number of sequences.
59-
* @param[in] inputDim input sequence dimension.
60-
* @param[in] contextLength context length.
61-
* @param[in] contextStart context start.
62-
* @param[in] beginPad number of extra timesteps added at the
63-
* beginning.
64-
* @param[in] isPadding trainable padding.
65-
*
66-
*/
67-
extern void hl_context_projection_forward(real* input,
68-
const int* sequence,
69-
real* weightData,
70-
real* output,
71-
int numSequences,
72-
int inputDim,
73-
int contextLength,
74-
int contextStart,
75-
int beginPad,
76-
bool isPadding);
77-
78-
/**
79-
* @brief Context projection backward data.
80-
*
81-
* @param[in] outputGrad output gradient.
82-
* @param[in] sequence sequence index.
83-
* @param[out] inputGrad input gradient.
84-
* @param[in] numSequences number of sequences.
85-
* @param[in] inputDim input sequence dimension.
86-
* @param[in] contextLength context length.
87-
* @param[in] contextStart context start.
88-
*
89-
*/
90-
extern void hl_context_projection_backward_data(real* outputGrad,
91-
const int* sequence,
92-
real* inputGrad,
93-
int numSequences,
94-
int inputDim,
95-
int contextLength,
96-
int contextStart);
97-
98-
/**
99-
* @brief Context projection backward weight.
100-
*
101-
* @param[in] outputGrad output gradient.
102-
* @param[in] sequence sequence index.
103-
* @param[out] weightGrad weight gradient.
104-
* @param[in] numSequences number of sequences.
105-
* @param[in] weightDim input sequence dimension.
106-
* @param[in] totalPad number of extra timesteps.
107-
* @param[in] contextLength context length.
108-
* @param[in] contextStart context start.
109-
* @param[in] beginPad number of extra timesteps added at the
110-
* beginning.
111-
*
112-
*/
113-
extern void hl_context_projection_backward_weight(real* outputGrad,
114-
const int* sequence,
115-
real* weightGrad,
116-
int numSequences,
117-
int weightDim,
118-
int totalPad,
119-
int contextLength,
120-
int contextStart,
121-
int beginPad);
122-
12351
/**
12452
* @brief Memory copy from sequence to batch.
12553
*

paddle/cuda/include/stub/hl_sequence_stub.h

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -27,35 +27,6 @@ inline void hl_max_sequence_forward(real* input,
2727
inline void hl_max_sequence_backward(
2828
real* outputGrad, int* index, real* inputGrad, int numSequences, int dim) {}
2929

30-
inline void hl_context_projection_forward(real* input,
31-
const int* sequence,
32-
real* weightData,
33-
real* output,
34-
int numSequences,
35-
int inputDim,
36-
int contextLength,
37-
int contextStart,
38-
int beginPad,
39-
bool isPadding) {}
40-
41-
inline void hl_context_projection_backward_data(real* outputGrad,
42-
const int* sequence,
43-
real* inputGrad,
44-
int numSequences,
45-
int inputDim,
46-
int contextLength,
47-
int contextStart) {}
48-
49-
inline void hl_context_projection_backward_weight(real* outputGrad,
50-
const int* sequence,
51-
real* weightGrad,
52-
int numSequences,
53-
int weightDim,
54-
int totalPad,
55-
int contextLength,
56-
int contextStart,
57-
int beginPad) {}
58-
5930
inline void hl_sequence2batch_copy(real* batch,
6031
real* sequence,
6132
const int* batchIndex,

paddle/cuda/src/hl_cuda_sequence.cu

Lines changed: 0 additions & 252 deletions
Original file line numberDiff line numberDiff line change
@@ -90,258 +90,6 @@ void hl_max_sequence_backward(real* outputGrad,
9090
CHECK_SYNC("hl_max_sequence_backward failed");
9191
}
9292

93-
template <bool padding>
94-
__global__ void KeContextProjectionForward(real* input,
95-
const int* sequence,
96-
real* weightData,
97-
real* output,
98-
int inputDim,
99-
int contextLength,
100-
int contextStart,
101-
int beginPad) {
102-
int idx = threadIdx.x;
103-
int blockSize = blockDim.x;
104-
int sequenceId = blockIdx.x;
105-
int seqStart = sequence[sequenceId];
106-
int seqEnd = sequence[sequenceId+1];
107-
real value = 0;
108-
109-
int instances = seqEnd - seqStart + contextLength - 1;
110-
output += seqStart * inputDim * contextLength;
111-
input += seqStart * inputDim;
112-
for (int k = 0; k <= inputDim / blockSize; k++) {
113-
if (idx < inputDim) {
114-
for (int i = 0; i < instances; i++) {
115-
// i + contextStart;
116-
if ((i + contextStart) < 0) {
117-
if (padding) {
118-
value = weightData[i * inputDim + idx];
119-
} else {
120-
continue;
121-
}
122-
} else if ((i + contextStart) >= (seqEnd - seqStart)) {
123-
if (padding) {
124-
value =
125-
weightData[(beginPad + i + contextStart - (seqEnd - seqStart)) *
126-
inputDim + idx];
127-
} else {
128-
continue;
129-
}
130-
} else {
131-
value = input[(i + contextStart) * inputDim + idx];
132-
}
133-
134-
int outx = (i - contextLength) < 0 ? i : (contextLength - 1);
135-
int outy = (i - contextLength) < 0 ? 0 : (i - (contextLength - 1));
136-
real* output_r =
137-
output + outy * inputDim * contextLength + outx * inputDim;
138-
for (int j = outy; j < seqEnd - seqStart; j++) {
139-
output_r[idx] += value;
140-
if (j - outy == outx) break;
141-
output_r += (contextLength - 1) * inputDim;
142-
}
143-
}
144-
}
145-
idx += blockSize;
146-
}
147-
}
148-
149-
void hl_context_projection_forward(real* input,
150-
const int* sequence,
151-
real* weightData,
152-
real* output,
153-
int numSequences,
154-
int inputDim,
155-
int contextLength,
156-
int contextStart,
157-
int beginPad,
158-
bool isPadding) {
159-
CHECK_NOTNULL(input);
160-
CHECK_NOTNULL(sequence);
161-
CHECK_NOTNULL(output);
162-
CHECK(!isPadding || weightData);
163-
164-
int blockSize = 128;
165-
int blocksX = numSequences;
166-
int blocksY = 1;
167-
dim3 threads(blockSize, 1);
168-
dim3 grid(blocksX, blocksY);
169-
170-
if (isPadding) {
171-
KeContextProjectionForward<true><<< grid, threads, 0, STREAM_DEFAULT >>>
172-
(input, sequence, weightData, output, inputDim,
173-
contextLength, contextStart, beginPad);
174-
} else {
175-
KeContextProjectionForward<false><<< grid, threads, 0, STREAM_DEFAULT >>>
176-
(input, sequence, weightData, output, inputDim,
177-
contextLength, contextStart, beginPad);
178-
}
179-
CHECK_SYNC("hl_context_projection_forward failed");
180-
}
181-
182-
__global__ void KeContextProjectionBackwardData(real* outputGrad,
183-
const int* sequence,
184-
real* inputGrad,
185-
int inputDim,
186-
int contextLength,
187-
int contextStart) {
188-
int idx = threadIdx.x;
189-
int blockSize = blockDim.x;
190-
int sequenceId = blockIdx.x;
191-
int seqStart = sequence[sequenceId];
192-
int seqEnd = sequence[sequenceId+1];
193-
real value = 0;
194-
195-
int instances = seqEnd - seqStart + contextLength - 1;
196-
outputGrad += seqStart * inputDim * contextLength;
197-
inputGrad += seqStart * inputDim;
198-
for (int k = 0; k <= inputDim / blockSize; k++) {
199-
if (idx < inputDim) {
200-
for (int i = 0; i < instances; i++) {
201-
if ((i + contextStart) < 0) {
202-
continue;
203-
} else if ((i + contextStart) >= (seqEnd - seqStart)) {
204-
continue;
205-
} else {
206-
// value = 0;
207-
value = inputGrad[(i + contextStart) * inputDim + idx];
208-
}
209-
210-
int outx = (i - contextLength) < 0 ? i : (contextLength - 1);
211-
int outy = (i - contextLength) < 0 ? 0 : (i - (contextLength - 1));
212-
real* output_r =
213-
outputGrad + outy * inputDim * contextLength + outx * inputDim;
214-
for (int j = outy; j < seqEnd - seqStart; j++) {
215-
value += output_r[idx];
216-
if (j - outy == outx) break;
217-
output_r += (contextLength - 1) * inputDim;
218-
}
219-
inputGrad[(i + contextStart) * inputDim + idx] = value;
220-
}
221-
}
222-
idx += blockSize;
223-
}
224-
}
225-
226-
void hl_context_projection_backward_data(real* outputGrad,
227-
const int* sequence,
228-
real* inputGrad,
229-
int numSequences,
230-
int inputDim,
231-
int contextLength,
232-
int contextStart) {
233-
CHECK_NOTNULL(outputGrad);
234-
CHECK_NOTNULL(sequence);
235-
CHECK_NOTNULL(inputGrad);
236-
237-
int blockSize = 128;
238-
int blocksX = numSequences;
239-
int blocksY = 1;
240-
dim3 threads(blockSize, 1);
241-
dim3 grid(blocksX, blocksY);
242-
KeContextProjectionBackwardData<<< grid, threads, 0, STREAM_DEFAULT >>>
243-
(outputGrad, sequence, inputGrad, inputDim, contextLength, contextStart);
244-
CHECK_SYNC("hl_context_projection_backward_data failed");
245-
}
246-
247-
template<int THREADS_X, int THREADS_Y>
248-
__global__ void KeContextProjectionBackwardWeight(real* outputGrad,
249-
const int* sequence,
250-
real* weightGrad,
251-
int numSequences,
252-
int weightDim,
253-
int contextLength,
254-
int contextStart,
255-
int beginPad) {
256-
__shared__ real sum_s[THREADS_Y][THREADS_X];
257-
int padOfBlock = (weightDim + THREADS_X - 1) / THREADS_X;
258-
const int idx = threadIdx.x;
259-
const int idy = threadIdx.y;
260-
int padId = blockIdx.x / padOfBlock;
261-
int weightIdx = idx + THREADS_X * (blockIdx.x % padOfBlock);
262-
int instanceId;
263-
real value = 0;
264-
real* output_r;
265-
266-
sum_s[idy][idx] = 0.0f;
267-
if (weightIdx < weightDim) {
268-
for (int seqId = idy; seqId < numSequences; seqId += THREADS_Y) {
269-
int seqStart = sequence[seqId];
270-
int seqEnd = sequence[seqId+1];
271-
output_r = outputGrad + seqStart * weightDim * contextLength;
272-
273-
if (contextStart < 0) {
274-
if (padId + contextStart < 0) {
275-
instanceId = padId;
276-
} else {
277-
// beginPad > 0;
278-
instanceId = (padId - beginPad) + (seqEnd - seqStart) - contextStart;
279-
}
280-
} else {
281-
if (padId + (seqEnd - seqStart) < contextStart) {
282-
continue;
283-
} else {
284-
// beginPad == 0;
285-
instanceId = padId + (seqEnd - seqStart) - contextStart;
286-
}
287-
}
288-
289-
int outx = (instanceId - contextLength) < 0 ?
290-
instanceId : (contextLength - 1);
291-
int outy = (instanceId - contextLength) < 0 ?
292-
0 : (instanceId - (contextLength - 1));
293-
output_r += outy * weightDim * contextLength + outx * weightDim;
294-
for (int j = outy; j < seqEnd - seqStart; j++) {
295-
value += output_r[weightIdx];
296-
if (j - outy == outx) break;
297-
output_r += (contextLength - 1) * weightDim;
298-
}
299-
}
300-
sum_s[idy][idx] = value;
301-
}
302-
__syncthreads();
303-
304-
for (int stride = THREADS_Y/2; stride > 0; stride = stride/2) {
305-
if (idy < stride) {
306-
sum_s[idy][idx] += sum_s[idy + stride][idx];
307-
}
308-
__syncthreads();
309-
}
310-
__syncthreads();
311-
312-
if (weightIdx < weightDim) {
313-
if (idy == 0) {
314-
weightGrad[padId * weightDim + weightIdx] += sum_s[0][idx];
315-
}
316-
}
317-
}
318-
319-
void hl_context_projection_backward_weight(real* outputGrad,
320-
const int* sequence,
321-
real* weightGrad,
322-
int numSequences,
323-
int weightDim,
324-
int totalPad,
325-
int contextLength,
326-
int contextStart,
327-
int beginPad) {
328-
CHECK_NOTNULL(outputGrad);
329-
CHECK_NOTNULL(sequence);
330-
CHECK_NOTNULL(weightGrad);
331-
332-
int threadsX = 32;
333-
int threadsY = 32;
334-
int blocksX = totalPad * ((weightDim + threadsX - 1) / threadsX);
335-
dim3 threads(threadsX, threadsY);
336-
dim3 grid(blocksX, 1);
337-
338-
KeContextProjectionBackwardWeight<32, 32>
339-
<<< grid, threads, 0, STREAM_DEFAULT >>>
340-
(outputGrad, sequence, weightGrad, numSequences, weightDim,
341-
contextLength, contextStart, beginPad);
342-
CHECK_SYNC("hl_context_projection_backward_weight failed");
343-
}
344-
34593
template<int blockDimX, int blockDimY, int gridDimX, bool AddRow>
34694
__global__ void KeMatrixAddRows(real* output,
34795
real* table,

paddle/function/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ if(WITH_TESTING)
1717
# file(GLOB test_files . *OpTest.cpp)
1818
# add_executable(${test_bin} EXCLUDE_FROM_ALL ${test_files})
1919
add_simple_unittest(CrossMapNormalOpTest)
20+
add_unittest(ContextProjectionOpTest
21+
ContextProjectionOpTest.cpp
22+
../gserver/tests/TestUtil.cpp)
2023
endif()
2124
endif()
2225

0 commit comments

Comments
 (0)