@@ -15,10 +15,12 @@ limitations under the License. */
 #pragma once

 #include <glog/logging.h>
+
 #include <algorithm>
 #include <functional>  // for multiplies
 #include <iterator>
 #include <vector>
+
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
@@ -30,6 +32,7 @@ limitations under the License. */
 #ifdef __NVCC__
 #include <cuda.h>
 #include <thrust/iterator/iterator_adaptor.h>
+
 #include "paddle/fluid/platform/cuda_device_function.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
 constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024;
@@ -194,11 +197,11 @@ void CommonForwardBroadcastCPU(const framework::Tensor *x,
 }

 #ifdef __NVCC__
-template <typename Functor, typename T>
+template <typename Functor, typename T, typename OutType = T>
 __global__ void CommonForwardBroadcastCUDAKernel(
     const int *x_strides_array, const int *y_strides_array,
-    const int *out_dims_array, const T *x, const T *y, T *out, int out_size,
-    int max_dim, Functor func, const bool is_xsize_larger) {
+    const int *out_dims_array, const T *x, const T *y, OutType *out,
+    int out_size, int max_dim, Functor func, const bool is_xsize_larger) {
   for (int out_index = blockIdx.x * blockDim.x + threadIdx.x;
        out_index < out_size; out_index += blockDim.x * gridDim.x) {
     int x_index = 0;
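
The OutType parameter defaults to T, so every existing instantiation of the kernel compiles unchanged; the point of the change is to let a functor whose result type differs from its operand type write straight into a differently typed output tensor. A minimal sketch of such a functor, assuming Paddle's HOSTDEVICE macro from paddle/fluid/platform/hostdevice.h; LessThanFunctor is a hypothetical name used only for illustration, not part of this patch:

#include "paddle/fluid/platform/hostdevice.h"

// Hypothetical functor: operands are T, the result is bool. The kernel
// above would then be instantiated as
//   CommonForwardBroadcastCUDAKernel<LessThanFunctor<T>, T, bool>.
template <typename T>
struct LessThanFunctor {
  HOSTDEVICE bool operator()(const T &a, const T &b) const { return a < b; }
};
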
@@ -220,7 +223,7 @@ __global__ void CommonForwardBroadcastCUDAKernel(
   }
 }

-template <typename Functor, typename T>
+template <typename Functor, typename T, typename OutType = T>
 void CommonForwardBroadcastCUDA(
     const framework::Tensor *x, const framework::Tensor *y,
     framework::Tensor *z, int *x_dims_array, int *y_dims_array,
@@ -230,7 +233,7 @@ void CommonForwardBroadcastCUDA(
   auto cplace = platform::CPUPlace();
   const T *x_data = x->data<T>();
   const T *y_data = y->data<T>();
-  T *out_data = z->mutable_data<T>(ctx.GetPlace());
+  OutType *out_data = z->mutable_data<OutType>(ctx.GetPlace());

   std::vector<int> x_strides_array(max_dim);
   std::vector<int> y_strides_array(max_dim);
@@ -268,7 +271,7 @@ void CommonForwardBroadcastCUDA(
   dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1);

   CommonForwardBroadcastCUDAKernel<
-      Functor, T><<<gird_size, block_size, 0, ctx.stream()>>>(
+      Functor, T, OutType><<<gird_size, block_size, 0, ctx.stream()>>>(
       x_strides_array_gpu, y_strides_array_gpu, out_dims_array_gpu, x_data,
       y_data, out_data, out_size, max_dim, func, is_xsize_larger);
 }
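
The launch above keeps the usual grid-stride pattern: each thread starts at blockIdx.x * blockDim.x + threadIdx.x and advances by blockDim.x * gridDim.x, so any grid size covers the full output. A self-contained sketch of that pattern with a distinct output type, using illustrative names rather than Paddle's (the broadcast stride arithmetic is omitted):

#include <cuda_runtime.h>

// Illustrative stand-alone kernel, not Paddle code: same grid-stride
// indexing as CommonForwardBroadcastCUDAKernel, minus the broadcast
// index computation.
template <typename Functor, typename T, typename OutType = T>
__global__ void ElementwiseSketchKernel(const T *x, const T *y, OutType *out,
                                        int n, Functor func) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += blockDim.x * gridDim.x) {
    out[i] = func(x[i], y[i]);  // OutType may differ from T, e.g. bool.
  }
}
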
@@ -1796,7 +1799,7 @@ void CommonElementwiseBroadcastForward(

   if (platform::is_gpu_place(ctx.GetPlace())) {
 #ifdef __NVCC__
-    CommonForwardBroadcastCUDA<Functor, T>(
+    CommonForwardBroadcastCUDA<Functor, T, OutType>(
         x, y, z, x_dims_array.data(), y_dims_array.data(),
         out_dims_array.data(), max_dim,
         ctx.template device_context<platform::CUDADeviceContext>(), func,
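
Read together, the hunks thread OutType from the forward helper down to the kernel launch. Purely for illustration, a bool-producing instantiation of the CUDA path could look roughly as below; the argument list is reconstructed only from what the hunks show, and LessThanFunctor is the hypothetical functor sketched earlier:

// Hypothetical instantiation: float operands, bool output. Argument order
// is inferred from the call sites visible in the hunks above.
CommonForwardBroadcastCUDA<LessThanFunctor<float>, float, bool>(
    x, y, z, x_dims_array.data(), y_dims_array.data(),
    out_dims_array.data(), max_dim,
    ctx.template device_context<platform::CUDADeviceContext>(),
    LessThanFunctor<float>(), /* is_xsize_larger = */ true);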