
Commit 05621f7

[cherry-pick] Add function comments and instructions to the Primitive API #36024
1 parent 6b4f2fb commit 05621f7

File tree: 2 files changed (+296, -198 lines)

paddle/fluid/operators/kernel_primitives/compute_primitives.h

Lines changed: 152 additions & 84 deletions
@@ -54,8 +54,8 @@ class MPTypeTrait<platform::float16> {
 };

 /**
- * @brief will be used in BlockYReduce, get the index of reduce_num in shared
- * memory
+ * @brief Will be used in BlockYReduce, get the index of reduce_num in shared
+ * memory.
  */
 __device__ __forceinline__ int SharedMemoryIndex(int index) {
   return (threadIdx.y + index) * blockDim.x + threadIdx.x;
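As a quick check of the index math: with blockDim.x = 32, threadIdx.x = 5, threadIdx.y = 2 and index = 1, SharedMemoryIndex returns (2 + 1) * 32 + 5 = 101, i.e. the caller's own column in the shared-memory row one step below its own.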
@@ -83,7 +83,7 @@ __device__ __forceinline__ T WarpReduce(T val, ReduceOp reducer) {
  */

 /**
- * @brief BlockXReduce reduce along blockDim.x
+ * @brief BlockXReduce reduces along blockDim.x.
  */
 template <typename T, typename ReduceOp>
 __device__ __forceinline__ T BlockXReduce(T val, ReduceOp reducer) {
@@ -115,7 +115,7 @@ __device__ __forceinline__ T BlockXReduce(T val, ReduceOp reducer) {
 }

 /**
- * @brief BlockYReduce reduce along blockDim.y
+ * @brief BlockYReduce reduces along blockDim.y.
  */
 template <typename T, typename ReduceOp>
 __device__ __forceinline__ T BlockYReduce(T val, ReduceOp reducer) {
@@ -135,24 +135,33 @@ __device__ __forceinline__ T BlockYReduce(T val, ReduceOp reducer) {
 }  // namespace details

 /**
- * @brief unary function
- * @param
- * T: data type of in
- * OutT: data type of out
- * NX: the cols of in
- * NY: the rows of in
- * BlockSize: the config of this device
- * OpFunc: compute functor which have an operator() as following
- *   template <typename T, typename OutT>
+ * @brief Perform unary calculation according to OpFunc. The sizes of the
+ * input and output are the same.
+ *
+ * @template parameters
+ * InT: Data type of in.
+ * OutT: Data type of out.
+ * NX: The number of data columns loaded by each thread.
+ * NY: The number of data rows loaded by each thread.
+ * BlockSize: Identifies the current device thread index method. For GPU,
+ * threadIdx.x is used as the thread index; for XPU, core_id() is used as
+ * the index. Currently only GPU is supported.
+ * OpFunc: Compute functor which has an operator() as following:
+ *   template <typename InT, typename OutT>
  *   struct XxxFunctor {
- *     HOSTDEVICE OutT operator()(const T& a) const {
+ *     HOSTDEVICE OutT operator()(const InT& a) const {
  *       return ...;
  *     }
  *   };
+ *
+ * @param
+ * out: The register pointer of out, the size is NX * NY.
+ * in: The register pointer of in, the size is NX * NY.
+ * compute: Compute function which is declared like OpFunc<InT, OutT>().
  */
-template <typename T, typename OutT, int NX, int NY, int BlockSize,
+template <typename InT, typename OutT, int NX, int NY, int BlockSize,
           class OpFunc>
-__device__ __forceinline__ void ElementwiseUnary(OutT* out, const T* in,
+__device__ __forceinline__ void ElementwiseUnary(OutT* out, const InT* in,
                                                  OpFunc compute) {
 #pragma unroll
   for (int idx = 0; idx < NX * NY; idx++) {
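For readers new to this API, a minimal usage sketch follows. It is not part of this commit: the functor and kernel names are hypothetical, HOSTDEVICE comes from paddle/fluid/platform/hostdevice.h, and per-thread loads/stores are written as plain loops for clarity.

#include "paddle/fluid/operators/kernel_primitives/compute_primitives.h"
#include "paddle/fluid/platform/hostdevice.h"

namespace kps = paddle::operators::kernel_primitives;

// Hypothetical functor matching the documented OpFunc shape.
template <typename InT, typename OutT>
struct DoubleFunctor {
  HOSTDEVICE OutT operator()(const InT& a) const { return a + a; }
};

// Each thread doubles kNX consecutive floats held in registers.
__global__ void DoubleKernel(float* out, const float* in) {
  constexpr int kNX = 4;
  float in_reg[kNX];
  float out_reg[kNX];
  int base = (blockIdx.x * blockDim.x + threadIdx.x) * kNX;
  for (int i = 0; i < kNX; ++i) in_reg[i] = in[base + i];    // global -> registers
  kps::ElementwiseUnary<float, float, kNX, 1, 1, DoubleFunctor<float, float>>(
      out_reg, in_reg, DoubleFunctor<float, float>());
  for (int i = 0; i < kNX; ++i) out[base + i] = out_reg[i];  // registers -> global
}

The same load-compute-store shape carries over to the other primitives below; only the functor arity and the buffer shapes change.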
@@ -161,25 +170,35 @@ __device__ __forceinline__ void ElementwiseUnary(OutT* out, const T* in,
 }

 /**
- * @brief binary function, in1 and in2 have same shape
- * @param
- * T: data type of in1, in2
- * OutT: data type of out
- * NX: the cols of in1, in2
- * NY: the rows of in1, in2
- * BlockSize: the config of this device
- * OpFunc: compute functor which have an operator() as following
- *   template <typename T, typename OutT>
+ * @brief Binary calculation according to OpFunc. The sizes of the input and
+ * output are the same.
+ *
+ * @template parameters
+ * InT: Data type of in1 and in2.
+ * OutT: Data type of out.
+ * NX: The number of data columns loaded by each thread.
+ * NY: The number of data rows loaded by each thread.
+ * BlockSize: Identifies the current device thread index method. For GPU,
+ * threadIdx.x is used as the thread index; for XPU, core_id() is used as
+ * the index. Currently only GPU is supported.
+ * OpFunc: Compute functor which has an operator() as following:
+ *   template <typename InT, typename OutT>
  *   struct XxxFunctor {
- *     HOSTDEVICE OutT operator()(const T& a, const T& b) const {
+ *     HOSTDEVICE OutT operator()(const InT& a, const InT& b) const {
  *       return ...;
  *     }
  *   };
+ *
+ * @param
+ * out: The register pointer of out, the size is NX * NY.
+ * in1: The register pointer of the first input, size is NX * NY.
+ * in2: The register pointer of the second input, size is NX * NY.
+ * compute: Compute function which is declared like OpFunc<InT, OutT>().
  */
-template <typename T, typename OutT, int NX, int NY, int BlockSize,
+template <typename InT, typename OutT, int NX, int NY, int BlockSize,
           class OpFunc>
-__device__ __forceinline__ void ElementwiseBinary(OutT* out, const T* in1,
-                                                  const T* in2,
+__device__ __forceinline__ void ElementwiseBinary(OutT* out, const InT* in1,
+                                                  const InT* in2,
                                                   OpFunc compute) {
 #pragma unroll
   for (int idx = 0; idx < NX * NY; ++idx) {
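The binary form only changes the functor arity. A hedged sketch with illustrative names, reusing the register-buffer convention from the unary example above:

// Hypothetical binary functor: elementwise sum.
template <typename InT, typename OutT>
struct AddFunctor {
  HOSTDEVICE OutT operator()(const InT& a, const InT& b) const {
    return a + b;
  }
};

// Inside a kernel, with in1_reg, in2_reg and out_reg of kNX floats each:
//   kps::ElementwiseBinary<float, float, kNX, 1, 1, AddFunctor<float, float>>(
//       out_reg, in1_reg, in2_reg, AddFunctor<float, float>());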
@@ -188,25 +207,38 @@ __device__ __forceinline__ void ElementwiseBinary(OutT* out, const T* in1,
 }

 /**
- * @brief ternary function, in1, in2 and in3 have same shape
- * @param
- * T: data type of in1, in2, in3
- * OutT: data type of out
- * NX: the cols of in1, in2
- * NY: the rows of in1, in2
- * BlockSize: the config of this device
- * OpFunc: compute functor which have an operator() as following
- *   template <typename T, typename OutT>
+ * @brief Ternary calculation according to OpFunc. The sizes of the input and
+ * output are the same.
+ *
+ * @template parameters
+ * InT: Data type of in1, in2 and in3.
+ * OutT: Data type of out.
+ * NX: The number of data columns loaded by each thread.
+ * NY: The number of data rows loaded by each thread.
+ * BlockSize: Identifies the current device thread index method. For GPU,
+ * threadIdx.x is used as the thread index; for XPU, core_id() is used as
+ * the index. Currently only GPU is supported.
+ * OpFunc: Compute functor which has an operator() as following:
+ *   template <typename InT, typename OutT>
  *   struct XxxFunctor {
- *     HOSTDEVICE OutT operator()(const T& a, const T& b, const T& c) const {
+ *     HOSTDEVICE OutT operator()(const InT& a, const InT& b, const InT& c)
+ *         const {
  *       return ...;
  *     }
  *   };
+ *
+ * @param
+ * out: The register pointer of out, the size is NX * NY.
+ * in1: The register pointer of the first input, size is NX * NY.
+ * in2: The register pointer of the second input, size is NX * NY.
+ * in3: The register pointer of the third input, size is NX * NY.
+ * compute: Compute function which is declared like OpFunc<InT, OutT>().
  */
-template <typename T, typename OutT, int NX, int NY, int BlockSize,
+template <typename InT, typename OutT, int NX, int NY, int BlockSize,
           class OpFunc>
-__device__ __forceinline__ void ElementwiseTernary(OutT* out, const T* in1,
-                                                   const T* in2, const T* in3,
+__device__ __forceinline__ void ElementwiseTernary(OutT* out, const InT* in1,
+                                                   const InT* in2,
+                                                   const InT* in3,
                                                    OpFunc compute) {
 #pragma unroll
   for (int idx = 0; idx < NX * NY; ++idx) {
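The ternary form follows the same pattern; a natural fit is a fused multiply-add, sketched here with hypothetical names:

// Hypothetical ternary functor: out = a * b + c.
template <typename InT, typename OutT>
struct FmaFunctor {
  HOSTDEVICE OutT operator()(const InT& a, const InT& b, const InT& c) const {
    return a * b + c;
  }
};

// Inside a kernel, with three input register buffers of kNX floats each:
//   kps::ElementwiseTernary<float, float, kNX, 1, 1, FmaFunctor<float, float>>(
//       out_reg, a_reg, b_reg, c_reg, FmaFunctor<float, float>());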
@@ -215,27 +247,36 @@ __device__ __forceinline__ void ElementwiseTernary(OutT* out, const T* in1,
 }

 /**
- * @brief a general function for elementwise computation, all inputs have
- * the same shape.
- * @param
- * T: data type of in1, in2, in3
- * OutT: data type of out
- * NX: the cols of in1, in2
- * NY: the rows of in1, in2
- * BlockSize: the config of this device
- * OpFunc: compute functor which have an operator() as following
- *   template <typename T, typename OutT>
+ * @brief Multivariate calculation according to OpFunc. The sizes of the
+ * input and output are the same.
+ *
+ * @template parameters
+ * InT: Data type of the inputs in ins.
+ * OutT: Data type of out.
+ * NX: The number of data columns loaded by each thread.
+ * NY: The number of data rows loaded by each thread.
+ * BlockSize: Identifies the current device thread index method. For GPU,
+ * threadIdx.x is used as the thread index; for XPU, core_id() is used as
+ * the index. Currently only GPU is supported.
+ * Arity: The number of inputs in ins.
+ * OpFunc: Compute functor which has an operator() as following:
+ *   template <typename InT, typename OutT>
  *   struct XxxFunctor {
- *     HOSTDEVICE OutT operator()(const T* args) const {
+ *     HOSTDEVICE OutT operator()(const InT* args) const {
  *       return ...;
  *     }
  *   };
+ *
+ * @param
+ * out: The register pointer of out, the size is NX * NY.
+ * ins: An array of register arrays, one of size NX * NY per input.
+ * compute: Compute function which is declared like OpFunc<InT, OutT>().
  */
-template <typename T, typename OutT, int NX, int NY, int BlockSize, int Arity,
+template <typename InT, typename OutT, int NX, int NY, int BlockSize, int Arity,
           class OpFunc>
-__device__ __forceinline__ void ElementwiseAny(OutT* out, T (*ins)[NX * NY],
+__device__ __forceinline__ void ElementwiseAny(OutT* out, InT (*ins)[NX * NY],
                                                OpFunc compute) {
-  T args[Arity];
+  InT args[Arity];
 #pragma unroll
   for (int idx = 0; idx < NX * NY; ++idx) {
 #pragma unroll
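ElementwiseAny gathers the inputs into a two-dimensional register array whose first dimension is Arity, and the functor receives one element from each input per call. A hedged sketch with illustrative names; the decay from ins[3][kNX] to the InT (*)[NX * NY] parameter is standard C++:

// Hypothetical Arity = 3 functor: sums one element from each input.
template <typename InT, typename OutT>
struct Sum3Functor {
  HOSTDEVICE OutT operator()(const InT* args) const {
    return args[0] + args[1] + args[2];
  }
};

// Inside a kernel:
//   float ins[3][kNX];  // row i holds this thread's data for input i
//   kps::ElementwiseAny<float, float, kNX, 1, 1, 3, Sum3Functor<float, float>>(
//       out_reg, ins, Sum3Functor<float, float>());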
@@ -247,15 +288,31 @@ __device__ __forceinline__ void ElementwiseAny(OutT* out, T (*ins)[NX * NY],
 }

 /**
- * @brief cycle binary function, in1's shape size is [1, NX], in2's shape size
- * is [NY, NX], out's shape size is [NY, NX]
+ * @brief Binary calculation according to OpFunc, where the shapes of in1 and
+ * in2 differ. The shape of in1 is [1, NX], the shape of in2 is [NY, NX], and
+ * the output shape is [NY, NX].
+ *
+ * @template parameters
+ * InT: Data type of in1 and in2.
+ * OutT: Data type of out.
+ * NX: The number of data columns loaded by each thread.
+ * NY: The number of data rows loaded by each thread.
+ * BlockSize: Identifies the current device thread index method. For GPU,
+ * threadIdx.x is used as the thread index; for XPU, core_id() is used as
+ * the index. Currently only GPU is supported.
+ * OpFunc: Compute functor which has an operator() as following:
+ *   template <typename InT, typename OutT>
+ *   struct XxxFunctor {
+ *     HOSTDEVICE OutT operator()(const InT& a, const InT& b) const {
+ *       return ...;
+ *     }
+ *   };
+ *
  * @param
- * T: data type of in1, in2
- * OutT: data type of out
- * NX: the cols of in1, in2
- * NY: the rows of in1, in2
- * BlockSize: the config of this device
- * OpFunc: compute functor eg: in1 + in2, in1 - in2
+ * out: The register pointer of out, the size is NX * NY.
+ * in1: The register pointer of the first input, size is NX * 1.
+ * in2: The register pointer of the second input, size is NX * NY.
+ * compute: Compute function which is declared like OpFunc<InT, OutT>().
  */
 template <typename T, typename OutT, int NX, int NY, int BlockSize,
           class OpFunc>
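CycleBinary differs from the primitives above only in broadcasting: the [1, NX] buffer in1 is reused for each of the NY rows of in2. A hedged call-site sketch, reusing the hypothetical AddFunctor from the binary example:

// Inside a kernel, broadcasting one row across kNY rows:
//   float row[kNX];            // in1, shape [1, NX]
//   float tile[kNY * kNX];     // in2, shape [NY, NX]
//   float out_reg[kNY * kNX];  // out, shape [NY, NX]
//   kps::CycleBinary<float, float, kNX, kNY, 1, AddFunctor<float, float>>(
//       out_reg, row, tile, AddFunctor<float, float>());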
@@ -272,26 +329,37 @@ __device__ __forceinline__ void CycleBinary(OutT* out, const T* in1,
 }

 /**
- * @brief reduce function, in's shape size is [NX, NY].
- * If ReduceMode == kLocalMode then reduce NX, the shape of out is [NY, 1],
- * if ReduceMode == kGlobalMode then reduce between different threads, the
- * shape of out is [NY, NX]. If reduce_last_dim is false and reduce_num was
- * split, BlockYReduce will be called. If reduce_last_dim is true and
- * reduce_num was split, BlockXReduce will be called
- * @typename
- * T: data type of in
- * NX: the cols of in
- * NY: the rows of in
- * BlockSize: the config of this device
- * OpFunc: reduce functor, eg: CustomSum, CustomMean in reduce_functor_op.h
- * @param:
- * reducer: reduce functor, eg: CustomSum<T>()
- * reduce_last_dim: if in's last dim need to be reduce then reduce_last_dim =
- * true
+ * @brief Reduce provides collective methods for computing a parallel
+ * reduction of items partitioned across a CUDA block, as well as within a
+ * thread. When ReduceMode == kLocalMode, each thread reduces along nx; when
+ * ReduceMode == kGlobalMode, shared memory is used to reduce between threads.
+ *
+ * @template parameters
+ * T: The type of data.
+ * NX: The number of data items continuously loaded by each thread.
+ * NY: The number of data rows loaded by each thread; only NY = 1 is
+ * supported.
+ * BlockSize: Identifies the current device thread index method. For GPU,
+ * threadIdx.x is used as the thread index; for XPU, core_id() is used as
+ * the index. Currently only GPU is supported.
+ * ReduceFunctor: Compute functor which has an operator() as following:
+ *   template <typename InT>
+ *   struct ReduceFunctor {
+ *     HOSTDEVICE InT operator()(const InT& a, const InT& b) const {
+ *       return ...;
+ *     }
+ *   };
+ * ReduceMode: Reduce mode, either kLocalMode or kGlobalMode.
+ *
+ * @param
+ * out: The register pointer of out, the size is NX * NY.
+ * in: The register pointer of in, the size is NX * NY.
+ * reducer: Compute function which is declared like ReduceFunctor<InT>().
+ * reduce_last_dim: whether the last dim gets involved in the reduction.
  */
-template <typename T, int NX, int NY, int BlockSize, class OpFunc,
+template <typename T, int NX, int NY, int BlockSize, class ReduceFunctor,
           details::ReduceMode Mode>
-__device__ __forceinline__ void Reduce(T* out, const T* in, OpFunc reducer,
+__device__ __forceinline__ void Reduce(T* out, const T* in,
+                                       ReduceFunctor reducer,
                                        bool reduce_last_dim) {
   int block_index = blockDim.y;

@@ -302,15 +370,15 @@ __device__ __forceinline__ void Reduce(T* out, const T* in, OpFunc reducer,
   if (block_reduce_y) {
 #pragma unroll
     for (int i = 0; i < NY * NX; i++) {  // reduce along blockdim.y
-      out[i] = details::BlockYReduce<T, OpFunc>(out[i], reducer);
+      out[i] = details::BlockYReduce<T, ReduceFunctor>(out[i], reducer);
     }
   }

   // when last dimension need to be reduced
   if (reduce_last_dim) {
 #pragma unroll
     for (int i = 0; i < NY * NX; i++) {  // reduce along blockDim.x
-      out[i] = details::BlockXReduce<T, OpFunc>(out[i], reducer);
+      out[i] = details::BlockXReduce<T, ReduceFunctor>(out[i], reducer);
     }
   }
 } else {  // else kLocalMode
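To tie the two modes together, here is a hedged sketch of a thread-local reduction. The names are hypothetical, the enum path kps::details::ReduceMode::kLocalMode is assumed from this header, and out is assumed to accumulate, so it is seeded first (safe here because max is idempotent):

// Hypothetical reduction functor.
template <typename T>
struct MaxFunctor {
  HOSTDEVICE T operator()(const T& a, const T& b) const {
    return a > b ? a : b;
  }
};

// Inside a kernel, folding kNX register values into one per thread:
//   float in_reg[kNX];
//   float result[1] = {in_reg[0]};  // assumed: Reduce accumulates into out
//   kps::Reduce<float, kNX, 1, 1, MaxFunctor<float>,
//               kps::details::ReduceMode::kLocalMode>(
//       result, in_reg, MaxFunctor<float>(), false);
// Passing Mode = kGlobalMode instead reduces across the block's threads via
// the BlockXReduce / BlockYReduce helpers shown earlier.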
