@@ -14,82 +14,150 @@ limitations under the License. */
 #include <thrust/execution_policy.h>
 #include <thrust/sort.h>
+#include "cub/cub.cuh"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/argsort_op.h"
+#include "paddle/fluid/operators/transpose_op.h"
 #include "paddle/fluid/platform/cuda_device_function.h"
 #include "paddle/fluid/platform/cuda_primitives.h"

+// Specialize cub's base traits so that float16 keys are handled correctly.
+namespace cub {
+template <>
+struct NumericTraits<paddle::platform::float16>
+    : BaseTraits<FLOATING_POINT, true, false, uint16_t,
+                 paddle::platform::float16> {};
+}  // namespace cub
+
 namespace paddle {
 namespace operators {

 using Tensor = framework::Tensor;
-using platform::PADDLE_CUDA_NUM_THREADS;
-
-const int kMaxRank = 9;  // The max rank of a tensor allowed in Fluid
-
-__global__ void ComputeTargetIdx(const int64_t* in_dims, int dims_size,
-                                 int axis, int64_t n, int64_t* trg_idx,
-                                 int64_t* med_ids) {
-  int64_t index = threadIdx.x + blockDim.x * blockIdx.x;
-  if (index < n) {
-    int64_t shape_out_axis[kMaxRank - 1] = {0};
-    int64_t dims_out_axis[kMaxRank - 1] = {0};
-    int64_t tmp = index;
-    int64_t pos_in_axis = 0;
-    int64_t i = dims_size - 2;
-    int64_t dim_axis = 0;
-    for (int64_t j = dims_size - 1; j >= 0; --j) {
-      int64_t dim = in_dims[j];
-      if (j != axis) {
-        shape_out_axis[i] = tmp % dim;
-        dims_out_axis[i] = dim;
-        i--;
-      } else {
-        dim_axis = dim;
-        pos_in_axis = tmp % dim_axis;
-      }
-      tmp /= dim;
-    }
-    int64_t group = (dims_size > 1) ? shape_out_axis[0] : 0;
-    for (int64_t j = 0; j < dims_size - 2; ++j) {
-      group = group * dims_out_axis[j + 1] + shape_out_axis[j + 1];
-    }

-    int64_t traget_idx = group * dim_axis + pos_in_axis;
-    trg_idx[index] = traget_idx;
-    med_ids[traget_idx] = pos_in_axis;
-  }
-}
+// Iterator that maps a row index to the offset of that row's first element
+struct SegmentOffsetIter {
+  EIGEN_DEVICE_FUNC
+  explicit SegmentOffsetIter(int num_cols) : num_cols_(num_cols) {}

-template <typename T>
-__global__ void PermuteInData(const T* in, const int64_t* trg_idx, int64_t n,
-                              T* med_out) {
-  int index = threadIdx.x + blockDim.x * blockIdx.x;
-  if (index < n) {
-    med_out[trg_idx[index]] = in[index];
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(int idx) const {
+    return idx * num_cols_;
   }
-}
+
+  int num_cols_;
+};

 template <typename T>
-__global__ void Sort(int64_t axis_dim, int64_t groups, T* med_out,
-                     int64_t* med_ids) {
-  int index = threadIdx.x + blockDim.x * blockIdx.x;
-  if (index < groups) {
-    thrust::sort_by_key(thrust::device, med_out + index * axis_dim,
-                        med_out + axis_dim * (1 + index),
-                        med_ids + index * axis_dim);
+static __global__ void FillIndex(T* indices, T num_rows, T num_cols) {
+  int col_id = threadIdx.x;
+  int row_id = blockIdx.x;
+
+  for (T j = row_id; j < num_rows; j += gridDim.x) {
+    for (T i = col_id; i < num_cols; i += blockDim.x) {
+      indices[j * num_cols + i] = i;
+    }
   }
 }

-template <typename T>
-__global__ void PermuteMediateData(const T* med_out, const int64_t* med_ids,
-                                   const int64_t* trg_idx, int64_t n, T* out,
-                                   int64_t* indices) {
-  int index = threadIdx.x + blockDim.x * blockIdx.x;
-  if (index < n) {
-    out[index] = med_out[trg_idx[index]];
-    indices[index] = med_ids[trg_idx[index]];
+// Sorts each row of the flattened 2-D input; `descending` selects descending
+// (true) or ascending (false, the default) order.
+template <typename T, typename IndType>
+void ArgFullSort(const platform::CUDADeviceContext& ctx, const Tensor* input,
+                 Tensor* output, Tensor* indices, const IndType num_rows,
+                 const IndType num_cols, const bool descending) {
+  auto cu_stream = ctx.stream();
+
+  Tensor input_indices;
+
+  const std::vector<IndType> dims = {num_rows, num_cols};
+  auto dim = framework::make_ddim(dims);
+  input_indices.Resize(dim);
+  input_indices.mutable_data<IndType>(ctx.GetPlace());
+
+  size_t temp_storage_bytes = -1;
+
+  auto ComputeBlockSize = [](IndType col) {
+    if (col > 512)
+      return 1024;
+    else if (col > 256 && col <= 512)
+      return 512;
+    else if (col > 128 && col <= 256)
+      return 256;
+    else if (col > 64 && col <= 128)
+      return 128;
+    else
+      return 64;
+  };
+
+  int block_size = ComputeBlockSize(num_cols);
+
+  int maxGridDimX = ctx.GetCUDAMaxGridDimSize().x;
+  // cap the grid at the device limit; in practice num_rows is smaller
+  int grid_size = num_rows < maxGridDimX ? num_rows : maxGridDimX;
+  // Initialize the index array: each row holds 0, 1, ..., num_cols - 1
+  FillIndex<<<grid_size, block_size, 0, cu_stream>>>(
+      input_indices.data<IndType>(), num_rows, num_cols);
+
+  T* sorted_out_ptr;
+  IndType* sorted_indices_ptr;
+
+  const T* inp = input->data<T>();
+  T* out = output->mutable_data<T>(ctx.GetPlace());
+  IndType* ind = indices->mutable_data<IndType>(ctx.GetPlace());
+
+  sorted_out_ptr = out;
+  sorted_indices_ptr = ind;
+
+  // counting iterator that enumerates 0, 1, 2, ...
+  cub::CountingInputIterator<IndType> counting_iter(0);
+  // segment offsets: element i is the start offset of row i (i * num_cols)
+  cub::TransformInputIterator<IndType, SegmentOffsetIter,
+                              cub::CountingInputIterator<IndType>>
+      segment_offsets_t(counting_iter, SegmentOffsetIter(num_cols));
+
+  cudaError_t err;
+  if (descending) {
+    err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
+        nullptr, temp_storage_bytes, inp, sorted_out_ptr,
+        input_indices.data<IndType>(), sorted_indices_ptr, num_cols * num_rows,
+        num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
+        cu_stream);
+  } else {
+    err = cub::DeviceSegmentedRadixSort::SortPairs(
+        nullptr, temp_storage_bytes, inp, sorted_out_ptr,
+        input_indices.data<IndType>(), sorted_indices_ptr, num_cols * num_rows,
+        num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
+        cu_stream);
+  }
+  PADDLE_ENFORCE_CUDA_SUCCESS(
+      err,
+      "ArgSortOP failed as could not launch "
+      "cub::DeviceSegmentedRadixSort::SortPairsDescending to calculate "
+      "temp_storage_bytes, status: %s.",
+      temp_storage_bytes, cudaGetErrorString(err));
+
+  Tensor temp_storage;
+  temp_storage.mutable_data<uint8_t>(ctx.GetPlace(), temp_storage_bytes);
+
+  if (descending) {
+    err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
+        temp_storage.data<uint8_t>(), temp_storage_bytes, inp, sorted_out_ptr,
+        input_indices.data<IndType>(), sorted_indices_ptr, num_cols * num_rows,
+        num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
+        cu_stream);
+  } else {
+    err = cub::DeviceSegmentedRadixSort::SortPairs(
+        temp_storage.data<uint8_t>(), temp_storage_bytes, inp, sorted_out_ptr,
+        input_indices.data<IndType>(), sorted_indices_ptr, num_cols * num_rows,
+        num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
+        cu_stream);
   }
-}
+
+  PADDLE_ENFORCE_CUDA_SUCCESS(
+      err,
+      "ArgSortOP failed as could not launch "
+      "cub::DeviceSegmentedRadixSort::SortPairsDescending to sort input, "
+      "temp_storage_bytes: %d, status: %s.",
+      temp_storage_bytes, cudaGetErrorString(err));
 }

 template <typename T>
@@ -100,51 +168,76 @@ class ArgsortOpCUDAKernel : public framework::OpKernel<T> {
     auto* output = ctx.Output<Tensor>("Out");
     auto* indices = ctx.Output<Tensor>("Indices");
     int axis = ctx.Attr<int>("axis");
+    bool descending = ctx.Attr<bool>("descending");

     auto in_dims = input->dims();
     axis = (axis < 0) ? (in_dims.size() + axis) : axis;

-    const T* in_data = input->data<T>();
-    T* out_data = output->mutable_data<T>(ctx.GetPlace());
-    int64_t* ids_data = indices->mutable_data<int64_t>(ctx.GetPlace());
-
     int64_t numel = input->numel();
     int64_t groups = numel / in_dims[axis];

-    std::vector<int64_t> in_dims_vec = vectorize(in_dims);
-    thrust::device_vector<int64_t> in_dims_dev(in_dims_vec.begin(),
-                                               in_dims_vec.end());
-    int64_t* in_dims_data = thrust::raw_pointer_cast(in_dims_dev.data());
-    // Mediate tensor for sorting data and indices
-    Tensor mediate_output, mediate_indices;
-    T* med_out_data =
-        mediate_output.mutable_data<T>(input->dims(), ctx.GetPlace());
-    int64_t* med_ids_data =
-        mediate_indices.mutable_data<int64_t>(in_dims, ctx.GetPlace());
-    // Target index of each element along the given axis in the mediate tensors
-    Tensor trg_idx_t;
-    int64_t* trg_idx = trg_idx_t.mutable_data<int64_t>(in_dims, ctx.GetPlace());
-
-    auto stream = ctx.cuda_device_context().stream();
-    const int num_threads = PADDLE_CUDA_NUM_THREADS;
-
-    ComputeTargetIdx<<<(numel - 1) / num_threads + 1, num_threads, 0, stream>>>(
-        in_dims_data, in_dims.size(), axis, numel, trg_idx, med_ids_data);
-
-    PermuteInData<<<(numel - 1) / num_threads + 1, num_threads, 0, stream>>>(
-        in_data, trg_idx, numel, med_out_data);
-
-    Sort<<<(groups - 1) / num_threads + 1, num_threads, 0, stream>>>(
-        in_dims[axis], groups, med_out_data, med_ids_data);
-
-    PermuteMediateData<<<(numel - 1) / num_threads + 1, num_threads, 0,
-                         stream>>>(med_out_data, med_ids_data, trg_idx, numel,
-                                   out_data, ids_data);
+    // Special case for a full sort along the last axis, speedup ~190x.
+    if (axis == -1 || axis + 1 == in_dims.size()) {
+      const int64_t input_height = framework::product(
+          framework::slice_ddim(in_dims, 0, in_dims.size() - 1));
+      const int64_t input_width = in_dims[in_dims.size() - 1];
+      const auto& dev_ctx = ctx.cuda_device_context();
+      ArgFullSort<T, int64_t>(dev_ctx, input, output, indices, input_height,
+                              input_width, descending);
+    } else {
+      // Not a full sort: transpose the target axis to the last position first
+      std::vector<int> trans;
+      for (int i = 0; i < axis; i++) {
+        trans.push_back(i);
+      }
+      trans.push_back(in_dims.size() - 1);
+      for (int i = axis + 1; i < in_dims.size() - 1; i++) {
+        trans.push_back(i);
+      }
+      trans.push_back(axis);
+      framework::DDim trans_dims(in_dims);
+      for (int i = 0; i < trans.size(); i++) {
+        trans_dims[i] = in_dims[trans[i]];
+      }
+
+      Tensor trans_inp;
+      T* trans_inp_data = trans_inp.mutable_data<T>(trans_dims, ctx.GetPlace());
+      int ndims = trans.size();
+      const auto& dev_ctx = ctx.cuda_device_context();
+      // Transpose the input so the sort axis becomes the last dimension
+      TransCompute<platform::CUDADeviceContext, T>(ndims, dev_ctx, *input,
+                                                   &trans_inp, trans);
+
+      const int64_t input_height = framework::product(
+          framework::slice_ddim(trans_dims, 0, trans_dims.size() - 1));
+      const int64_t input_width = trans_dims[trans_dims.size() - 1];
+
+      Tensor tmp_out;
+      tmp_out.mutable_data<T>(trans_dims, ctx.GetPlace());
+      T* out_data = output->mutable_data<T>(ctx.GetPlace());
+
+      Tensor tmp_indices;
+      // temporary indices for sorting
+      tmp_indices.mutable_data<int64_t>(trans_dims, ctx.GetPlace());
+      indices->mutable_data<int64_t>(ctx.GetPlace());
+
+      ArgFullSort<T, int64_t>(dev_ctx, &trans_inp, &tmp_out, &tmp_indices,
+                              input_height, input_width, descending);
+
+      TransCompute<platform::CUDADeviceContext, int64_t>(
+          ndims, dev_ctx, tmp_indices, indices, trans);
+      // transpose the sorted values back
+      TransCompute<platform::CUDADeviceContext, T>(ndims, dev_ctx, tmp_out,
+                                                   output, trans);
+      return;
+    }
   }
 };

 }  // namespace operators
 }  // namespace paddle

-REGISTER_OP_CUDA_KERNEL(argsort, paddle::operators::ArgsortOpCUDAKernel<float>,
-                        paddle::operators::ArgsortOpCUDAKernel<double>);
+REGISTER_OP_CUDA_KERNEL(
+    argsort, paddle::operators::ArgsortOpCUDAKernel<float>,
+    paddle::operators::ArgsortOpCUDAKernel<double>,
+    paddle::operators::ArgsortOpCUDAKernel<paddle::platform::float16>);
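
For reference, the ArgFullSort helper added in the first hunk follows CUB's usual two-pass pattern: the first call to cub::DeviceSegmentedRadixSort::SortPairs (or SortPairsDescending) is made with a null workspace pointer and only computes temp_storage_bytes; once the workspace is allocated, the second call performs the actual row-wise sort. The standalone sketch below is not part of the patch; it shows the same pattern on a tiny 2x4 matrix, with a plain device offset array standing in for the patch's SegmentOffsetIter/TransformInputIterator pair, and assumes only a CUDA toolchain with CUB available.

// segmented_argsort_sketch.cu -- illustrative only, not part of the patch.
#include <cub/cub.cuh>
#include <cuda_runtime.h>
#include <cstdio>

int main() {
  // Two rows of four keys each; sort every row independently and carry the
  // original column index of each key along as the "value" (an argsort).
  const int num_rows = 2, num_cols = 4, n = num_rows * num_cols;
  float h_keys[n] = {3.f, 1.f, 4.f, 2.f, 9.f, 7.f, 8.f, 6.f};
  int h_idx[n] = {0, 1, 2, 3, 0, 1, 2, 3};
  int h_offsets[num_rows + 1] = {0, 4, 8};  // row i is [offsets[i], offsets[i+1])

  float *d_keys_in, *d_keys_out;
  int *d_idx_in, *d_idx_out, *d_offsets;
  cudaMalloc(&d_keys_in, sizeof(h_keys));
  cudaMalloc(&d_keys_out, sizeof(h_keys));
  cudaMalloc(&d_idx_in, sizeof(h_idx));
  cudaMalloc(&d_idx_out, sizeof(h_idx));
  cudaMalloc(&d_offsets, sizeof(h_offsets));
  cudaMemcpy(d_keys_in, h_keys, sizeof(h_keys), cudaMemcpyHostToDevice);
  cudaMemcpy(d_idx_in, h_idx, sizeof(h_idx), cudaMemcpyHostToDevice);
  cudaMemcpy(d_offsets, h_offsets, sizeof(h_offsets), cudaMemcpyHostToDevice);

  // Pass 1: with a null workspace, CUB only reports the temp bytes it needs.
  void* d_temp = nullptr;
  size_t temp_bytes = 0;
  cub::DeviceSegmentedRadixSort::SortPairs(d_temp, temp_bytes, d_keys_in,
                                           d_keys_out, d_idx_in, d_idx_out, n,
                                           num_rows, d_offsets, d_offsets + 1);
  cudaMalloc(&d_temp, temp_bytes);

  // Pass 2: the actual per-row (segmented) ascending sort.
  cub::DeviceSegmentedRadixSort::SortPairs(d_temp, temp_bytes, d_keys_in,
                                           d_keys_out, d_idx_in, d_idx_out, n,
                                           num_rows, d_offsets, d_offsets + 1);

  cudaMemcpy(h_keys, d_keys_out, sizeof(h_keys), cudaMemcpyDeviceToHost);
  cudaMemcpy(h_idx, d_idx_out, sizeof(h_idx), cudaMemcpyDeviceToHost);
  for (int i = 0; i < n; ++i) printf("%g (original column %d)\n", h_keys[i], h_idx[i]);
  return 0;
}

The patch fills the index "values" with its FillIndex kernel and computes the row offsets on the fly instead of materializing them, but the call sequence into cub is the same.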
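
When the sort axis is not the last one, the else branch of Compute builds a permutation vector `trans` that swaps `axis` with the last dimension and leaves every other dimension where it is; because such a swap is its own inverse, the very same vector is handed to TransCompute again to move the sorted values and indices back. A minimal host-side sketch of that permutation logic, mirroring the loops in the patch (illustrative only):

// build_trans_sketch.cc -- illustrative only, not part of the patch.
#include <cstdio>
#include <vector>

// Mirrors the else branch: keep the dims before `axis`, move the last dim
// into position `axis`, keep the dims in between, and append `axis` last.
std::vector<int> BuildTrans(int rank, int axis) {
  std::vector<int> trans;
  for (int i = 0; i < axis; i++) trans.push_back(i);
  trans.push_back(rank - 1);
  for (int i = axis + 1; i < rank - 1; i++) trans.push_back(i);
  trans.push_back(axis);
  return trans;
}

int main() {
  // A rank-4 tensor sorted along axis 1 is permuted as "0 3 2 1":
  // dimensions 1 and 3 trade places, dimensions 0 and 2 stay put.
  for (int d : BuildTrans(/*rank=*/4, /*axis=*/1)) printf("%d ", d);
  printf("\n");
  return 0;
}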