@@ -12,18 +12,138 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

+ #include <thrust/device_vector.h>
+ #include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/slice_op.h"
+ #include "paddle/fluid/platform/cuda_device_function.h"
+ #include "paddle/fluid/platform/cuda_primitives.h"
+ #include "paddle/fluid/platform/float16.h"
+
+ namespace paddle {
+ namespace operators {
+
+ using platform::PADDLE_CUDA_NUM_THREADS;
+
+ template <size_t D>
+ __global__ void Padding(const paddle::platform::float16* d_out,
+                         const int* out_dims, const int* in_dims,
+                         const int* offsets, int64_t n,
+                         paddle::platform::float16* d_in) {
+   int64_t out_idx = threadIdx.x + blockDim.x * blockIdx.x;
+   if (out_idx < n) {
+     // Decompose the linear output index into D coordinates and shift each
+     // coordinate by the slice offset along that axis.
+     int64_t out_idx_tmp = out_idx;
+     int coords[D] = {0};
+     for (int i = D - 1; i >= 0; --i) {
+       coords[i] = out_idx_tmp % out_dims[i];
+       out_idx_tmp /= out_dims[i];
+       coords[i] += offsets[i];
+     }
+
+     // Re-linearize the shifted coordinates against the input shape
+     // (row-major) to locate the target element of the input gradient.
+     int64_t in_idx = 0;
+     for (int i = 0; i < D; ++i) {
+       in_idx = in_idx * in_dims[i] + coords[i];
+     }
+
+     d_in[in_idx] = d_out[out_idx];
+   }
+ }
+
+ template <>
+ class SliceGradKernel<paddle::platform::CUDADeviceContext,
+                       paddle::platform::float16>
+     : public framework::OpKernel<paddle::platform::float16> {
+  public:
+   void Compute(const framework::ExecutionContext& ctx) const override {
+     auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
+     auto* d_in = ctx.Output<framework::Tensor>(framework::GradVarName("Input"));
+     d_in->mutable_data<paddle::platform::float16>(ctx.GetPlace());
+
+     auto out_dims = d_out->dims();
+     auto in_dims = d_in->dims();
+     int rank = out_dims.size();
+     std::vector<int> offsets(rank, 0);
+     auto axes = ctx.Attr<std::vector<int>>("axes");
+     auto starts = ctx.Attr<std::vector<int>>("starts");
+
+     // Resolve negative "starts" and record the slice offset along each axis.
+     for (size_t i = 0; i < starts.size(); ++i) {
+       if (starts[i] < 0) {
+         starts[i] += in_dims[axes[i]];
+       }
+       offsets[axes[i]] = std::max(starts[i], 0);
+     }
+
+     // Elements outside the slice receive zero gradient, so zero-fill d_in
+     // before scattering d_out back into it.
+     math::SetConstant<paddle::platform::CUDADeviceContext,
+                       paddle::platform::float16>
+         set_zero;
+     auto& dev_ctx =
+         ctx.template device_context<paddle::platform::CUDADeviceContext>();
+     set_zero(dev_ctx, d_in, static_cast<paddle::platform::float16>(0));
+
+     int64_t numel = d_out->numel();
+     dim3 blocks((numel - 1) / PADDLE_CUDA_NUM_THREADS + 1, 1, 1);
+     dim3 threads(PADDLE_CUDA_NUM_THREADS, 1, 1);
+     auto stream = ctx.cuda_device_context().stream();
+
+     // Copy shape and offset metadata to device memory so the kernel can
+     // read it.
+     auto out_shape = framework::vectorize2int(out_dims);
+     thrust::device_vector<int> out_dims_vec(out_shape.begin(), out_shape.end());
+     auto in_shape = framework::vectorize2int(in_dims);
+     thrust::device_vector<int> in_dims_vec(in_shape.begin(), in_shape.end());
+     thrust::device_vector<int> offsets_vec(offsets.begin(), offsets.end());
+     const int* out_dims_ptr = thrust::raw_pointer_cast(out_dims_vec.data());
+     const int* in_dims_ptr = thrust::raw_pointer_cast(in_dims_vec.data());
+     const int* offsets_ptr = thrust::raw_pointer_cast(offsets_vec.data());
+
+     switch (rank) {
+       case 1:
+         Padding<1><<<blocks, threads, 0, stream>>>(
+             d_out->data<paddle::platform::float16>(), out_dims_ptr, in_dims_ptr,
+             offsets_ptr, numel, d_in->data<paddle::platform::float16>());
+         break;
+       case 2:
+         Padding<2><<<blocks, threads, 0, stream>>>(
+             d_out->data<paddle::platform::float16>(), out_dims_ptr, in_dims_ptr,
+             offsets_ptr, numel, d_in->data<paddle::platform::float16>());
+         break;
+       case 3:
+         Padding<3><<<blocks, threads, 0, stream>>>(
+             d_out->data<paddle::platform::float16>(), out_dims_ptr, in_dims_ptr,
+             offsets_ptr, numel, d_in->data<paddle::platform::float16>());
+         break;
+       case 4:
+         Padding<4><<<blocks, threads, 0, stream>>>(
+             d_out->data<paddle::platform::float16>(), out_dims_ptr, in_dims_ptr,
+             offsets_ptr, numel, d_in->data<paddle::platform::float16>());
+         break;
+       case 5:
+         Padding<5><<<blocks, threads, 0, stream>>>(
+             d_out->data<paddle::platform::float16>(), out_dims_ptr, in_dims_ptr,
+             offsets_ptr, numel, d_in->data<paddle::platform::float16>());
+         break;
+       case 6:
+         Padding<6><<<blocks, threads, 0, stream>>>(
+             d_out->data<paddle::platform::float16>(), out_dims_ptr, in_dims_ptr,
+             offsets_ptr, numel, d_in->data<paddle::platform::float16>());
+         break;
+     }
+   }
+ };
+
+ }  // namespace operators
+ }  // namespace paddle

namespace ops = paddle::operators;
+ namespace plat = paddle::platform;

REGISTER_OP_CUDA_KERNEL(
    slice, ops::SliceKernel<paddle::platform::CUDADeviceContext, float>,
    ops::SliceKernel<paddle::platform::CUDADeviceContext, double>,
    ops::SliceKernel<paddle::platform::CUDADeviceContext, int>,
-     ops::SliceKernel<paddle::platform::CUDADeviceContext, int64_t>);
+     ops::SliceKernel<paddle::platform::CUDADeviceContext, int64_t>,
+     ops::SliceKernel<paddle::platform::CUDADeviceContext, plat::float16>);

REGISTER_OP_CUDA_KERNEL(
    slice_grad,
    ops::SliceGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::SliceGradKernel<paddle::platform::CUDADeviceContext, double>,
    ops::SliceGradKernel<paddle::platform::CUDADeviceContext, int>,
-     ops::SliceGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
+     ops::SliceGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
+     ops::SliceGradKernel<paddle::platform::CUDADeviceContext, plat::float16>);
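
For reference, the index arithmetic used by the Padding kernel above can be exercised outside of Paddle. The following standalone CUDA sketch is not part of this patch; the tensor shapes, the RANK constant, and the PadBack name are made up for illustration, and plain float is used instead of float16 for simplicity. It zero-fills an input-gradient buffer and scatters a sliced output gradient back into it with the same row-major coordinate decomposition.

#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

// Rank of the illustrative tensors (hypothetical; any small rank works).
constexpr int RANK = 3;

// Scatter every element of d_out (shape out_dims) into d_in (shape in_dims),
// shifted by the per-axis slice offsets, mirroring the Padding kernel above.
__global__ void PadBack(const float* d_out, const int* out_dims,
                        const int* in_dims, const int* offsets, int64_t n,
                        float* d_in) {
  int64_t out_idx = threadIdx.x + static_cast<int64_t>(blockDim.x) * blockIdx.x;
  if (out_idx >= n) return;

  // Linear output index -> coordinates, plus the slice offsets.
  int coords[RANK];
  int64_t tmp = out_idx;
  for (int i = RANK - 1; i >= 0; --i) {
    coords[i] = tmp % out_dims[i] + offsets[i];
    tmp /= out_dims[i];
  }

  // Coordinates -> linear index into the larger input gradient (row-major).
  int64_t in_idx = 0;
  for (int i = 0; i < RANK; ++i) in_idx = in_idx * in_dims[i] + coords[i];

  d_in[in_idx] = d_out[out_idx];
}

int main() {
  // Hypothetical example: input 2x4x5, slice offsets (0, 1, 2), output 2x2x3.
  std::vector<int> in_dims{2, 4, 5}, out_dims{2, 2, 3}, offsets{0, 1, 2};
  int64_t in_n = 2 * 4 * 5, out_n = 2 * 2 * 3;

  std::vector<float> h_out(out_n);
  for (int64_t i = 0; i < out_n; ++i) h_out[i] = static_cast<float>(i + 1);

  float *d_out = nullptr, *d_in = nullptr;
  int *d_od = nullptr, *d_id = nullptr, *d_off = nullptr;
  cudaMalloc(&d_out, out_n * sizeof(float));
  cudaMalloc(&d_in, in_n * sizeof(float));
  cudaMalloc(&d_od, RANK * sizeof(int));
  cudaMalloc(&d_id, RANK * sizeof(int));
  cudaMalloc(&d_off, RANK * sizeof(int));
  cudaMemcpy(d_out, h_out.data(), out_n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_od, out_dims.data(), RANK * sizeof(int), cudaMemcpyHostToDevice);
  cudaMemcpy(d_id, in_dims.data(), RANK * sizeof(int), cudaMemcpyHostToDevice);
  cudaMemcpy(d_off, offsets.data(), RANK * sizeof(int), cudaMemcpyHostToDevice);

  // Gradient outside the slice is zero, so clear the whole input buffer first.
  cudaMemset(d_in, 0, in_n * sizeof(float));

  PadBack<<<(out_n + 255) / 256, 256>>>(d_out, d_od, d_id, d_off, out_n, d_in);

  std::vector<float> h_in(in_n);
  cudaMemcpy(h_in.data(), d_in, in_n * sizeof(float), cudaMemcpyDeviceToHost);
  // Input element (0, 1, 2), i.e. flat index 7, receives d_out's first value.
  std::printf("d_in[0] = %g, d_in[7] = %g\n", h_in[0], h_in[7]);
  return 0;
}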