Commit 66c18f4

[cherry-pick] Improve argsort performance. (#21267) (#21442)
* Improve argsort performance.
  - Sorting 200,000 elements on a V100 is ~190x faster: 0.53 s before the optimization, 0.0027 s after.
  - Add fp16 support.
* Refine error message.
* Refine code.
* Add descending sort.

test=develop

Signed-off-by: zhaoyuchen <[email protected]>
1 parent 735a2db commit 66c18f4

6 files changed: +395 / -134 lines


paddle/fluid/operators/argsort_op.cc

Lines changed: 7 additions & 0 deletions
@@ -73,6 +73,13 @@ Output(Indices) gives the sorted order along the given axis Attr(axis).
                  "When axis < 0, the actual axis will be the |axis|'th "
                  "counting backwards. Default -1, the last dimension.")
         .SetDefault(-1);
+    AddAttr<bool>(
+        "descending",
+        "(bool, default false) The descending attribute is a flag to tell"
+        "algorithm how to sort the input data."
+        "If descending is true, will sort by descending order,"
+        "else if false, sort by ascending order. Default value is false.")
+        .SetDefault(false);
   }
 };
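The new attribute only flips the order in which Out and Indices are produced; ascending stays the default, so existing users see no behavior change. As a rough illustration of the documented semantics, here is a minimal host-side C++ sketch (not Paddle code; argsort_1d is a hypothetical helper) of what Out and Indices contain for a 1-D input under both settings:

// Host-side illustration (not Paddle code) of what Out and Indices hold for a
// 1-D input, with and without the new "descending" attribute.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <numeric>
#include <vector>

void argsort_1d(const std::vector<float>& in, bool descending,
                std::vector<float>* out, std::vector<int64_t>* indices) {
  indices->resize(in.size());
  std::iota(indices->begin(), indices->end(), 0);  // 0, 1, 2, ...
  std::stable_sort(indices->begin(), indices->end(),
                   [&](int64_t a, int64_t b) {
                     return descending ? in[a] > in[b] : in[a] < in[b];
                   });
  out->resize(in.size());
  for (size_t i = 0; i < in.size(); ++i) (*out)[i] = in[(*indices)[i]];
}

int main() {
  std::vector<float> x = {3.f, 1.f, 2.f};
  std::vector<float> out;
  std::vector<int64_t> idx;
  argsort_1d(x, /*descending=*/false, &out, &idx);  // Out: 1 2 3, Indices: 1 2 0
  argsort_1d(x, /*descending=*/true, &out, &idx);   // Out: 3 2 1, Indices: 0 2 1
  for (size_t i = 0; i < out.size(); ++i)
    std::printf("%g (index %lld)\n", out[i], static_cast<long long>(idx[i]));
  return 0;
}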

paddle/fluid/operators/argsort_op.cu

Lines changed: 186 additions & 93 deletions
@@ -14,82 +14,150 @@ limitations under the License. */
 
 #include <thrust/execution_policy.h>
 #include <thrust/sort.h>
+#include "cub/cub.cuh"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/argsort_op.h"
+#include "paddle/fluid/operators/transpose_op.h"
 #include "paddle/fluid/platform/cuda_device_function.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
 
+// set cub base traits in order to handle float16
+namespace cub {
+template <>
+struct NumericTraits<paddle::platform::float16>
+    : BaseTraits<FLOATING_POINT, true, false, uint16_t,
+                 paddle::platform::float16> {};
+}  // namespace cub
+
 namespace paddle {
 namespace operators {
 
 using Tensor = framework::Tensor;
-using platform::PADDLE_CUDA_NUM_THREADS;
-
-const int kMaxRank = 9;  // The max rank of a tensor allowed in Fluid
-
-__global__ void ComputeTargetIdx(const int64_t* in_dims, int dims_size,
-                                 int axis, int64_t n, int64_t* trg_idx,
-                                 int64_t* med_ids) {
-  int64_t index = threadIdx.x + blockDim.x * blockIdx.x;
-  if (index < n) {
-    int64_t shape_out_axis[kMaxRank - 1] = {0};
-    int64_t dims_out_axis[kMaxRank - 1] = {0};
-    int64_t tmp = index;
-    int64_t pos_in_axis = 0;
-    int64_t i = dims_size - 2;
-    int64_t dim_axis = 0;
-    for (int64_t j = dims_size - 1; j >= 0; --j) {
-      int64_t dim = in_dims[j];
-      if (j != axis) {
-        shape_out_axis[i] = tmp % dim;
-        dims_out_axis[i] = dim;
-        i--;
-      } else {
-        dim_axis = dim;
-        pos_in_axis = tmp % dim_axis;
-      }
-      tmp /= dim;
-    }
-    int64_t group = (dims_size > 1) ? shape_out_axis[0] : 0;
-    for (int64_t j = 0; j < dims_size - 2; ++j) {
-      group = group * dims_out_axis[j + 1] + shape_out_axis[j + 1];
-    }
 
-    int64_t traget_idx = group * dim_axis + pos_in_axis;
-    trg_idx[index] = traget_idx;
-    med_ids[traget_idx] = pos_in_axis;
-  }
-}
+// Iter for move to next row
+struct SegmentOffsetIter {
+  EIGEN_DEVICE_FUNC
+  explicit SegmentOffsetIter(int num_cols) : num_cols_(num_cols) {}
 
-template <typename T>
-__global__ void PermuteInData(const T* in, const int64_t* trg_idx, int64_t n,
-                              T* med_out) {
-  int index = threadIdx.x + blockDim.x * blockIdx.x;
-  if (index < n) {
-    med_out[trg_idx[index]] = in[index];
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(int idx) const {
+    return idx * num_cols_;
   }
-}
+
+  int num_cols_;
+};
 
 template <typename T>
-__global__ void Sort(int64_t axis_dim, int64_t groups, T* med_out,
-                     int64_t* med_ids) {
-  int index = threadIdx.x + blockDim.x * blockIdx.x;
-  if (index < groups) {
-    thrust::sort_by_key(thrust::device, med_out + index * axis_dim,
-                        med_out + axis_dim * (1 + index),
-                        med_ids + index * axis_dim);
+static __global__ void FillIndex(T* indices, T num_rows, T num_cols) {
+  int col_id = threadIdx.x;
+  int row_id = blockIdx.x;
+
+  for (T j = row_id; j < num_rows; j += gridDim.x) {
+    for (T i = col_id; i < num_cols; i += blockDim.x) {
+      indices[j * num_cols + i] = i;
+    }
   }
 }
 
-template <typename T>
-__global__ void PermuteMediateData(const T* med_out, const int64_t* med_ids,
-                                   const int64_t* trg_idx, int64_t n, T* out,
-                                   int64_t* indices) {
-  int index = threadIdx.x + blockDim.x * blockIdx.x;
-  if (index < n) {
-    out[index] = med_out[trg_idx[index]];
-    indices[index] = med_ids[trg_idx[index]];
+// Sort by flag descending, True: descending. False: Ascending.
+// Default is false.
+template <typename T, typename IndType>
+void ArgFullSort(const platform::CUDADeviceContext& ctx, const Tensor* input,
+                 Tensor* output, Tensor* indices, const IndType num_rows,
+                 const IndType num_cols, const bool descending) {
+  auto cu_stream = ctx.stream();
+
+  Tensor input_indices;
+
+  const std::vector<IndType> dims = {num_rows, num_cols};
+  auto dim = framework::make_ddim(dims);
+  input_indices.Resize(dim);
+  input_indices.mutable_data<IndType>(ctx.GetPlace());
+
+  size_t temp_storage_bytes = -1;
+
+  auto ComputeBlockSize = [](IndType col) {
+    if (col > 512)
+      return 1024;
+    else if (col > 256 && col <= 512)
+      return 512;
+    else if (col > 128 && col <= 256)
+      return 256;
+    else if (col > 64 && col <= 128)
+      return 128;
+    else
+      return 64;
+  };
+
+  int block_size = ComputeBlockSize(num_cols);
+
+  int maxGridDimX = ctx.GetCUDAMaxGridDimSize().x;
+  // actually, int num_rows < max_grid_size
+  int grid_size = num_rows < maxGridDimX ? num_rows : maxGridDimX;
+  // Init a index array
+  FillIndex<<<grid_size, block_size, 0, cu_stream>>>(
+      input_indices.data<IndType>(), num_rows, num_cols);
+
+  T* sorted_out_ptr;
+  IndType* sorted_indices_ptr;
+
+  const T* inp = input->data<T>();
+  T* out = output->mutable_data<T>(ctx.GetPlace());
+  IndType* ind = indices->mutable_data<IndType>(ctx.GetPlace());
+
+  sorted_out_ptr = out;
+  sorted_indices_ptr = ind;
+
+  // create iter for counting input
+  cub::CountingInputIterator<IndType> counting_iter(0);
+  // segment_offset is used for move to next row
+  cub::TransformInputIterator<IndType, SegmentOffsetIter,
+                              cub::CountingInputIterator<IndType>>
+      segment_offsets_t(counting_iter, SegmentOffsetIter(num_cols));
+
+  cudaError_t err;
+  if (descending) {
+    err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
+        nullptr, temp_storage_bytes, inp, sorted_out_ptr,
+        input_indices.data<IndType>(), sorted_indices_ptr, num_cols * num_rows,
+        num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
+        cu_stream);
+  } else {
+    err = cub::DeviceSegmentedRadixSort::SortPairs(
+        nullptr, temp_storage_bytes, inp, sorted_out_ptr,
+        input_indices.data<IndType>(), sorted_indices_ptr, num_cols * num_rows,
+        num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
+        cu_stream);
+  }
+  PADDLE_ENFORCE_CUDA_SUCCESS(
+      err,
+      "ArgSortOP failed as could not launch "
+      "cub::DeviceSegmentedRadixSort::SortPairsDescending to calculate"
+      "temp_storage_bytes, status:%s.",
+      temp_storage_bytes, cudaGetErrorString(err));
+
+  Tensor temp_storage;
+  temp_storage.mutable_data<uint8_t>(ctx.GetPlace(), temp_storage_bytes);
+
+  if (descending) {
+    err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
+        temp_storage.data<uint8_t>(), temp_storage_bytes, inp, sorted_out_ptr,
+        input_indices.data<IndType>(), sorted_indices_ptr, num_cols * num_rows,
+        num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
+        cu_stream);
+  } else {
+    err = cub::DeviceSegmentedRadixSort::SortPairs(
+        temp_storage.data<uint8_t>(), temp_storage_bytes, inp, sorted_out_ptr,
+        input_indices.data<IndType>(), sorted_indices_ptr, num_cols * num_rows,
+        num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
+        cu_stream);
   }
+
+  PADDLE_ENFORCE_CUDA_SUCCESS(
+      err,
+      "ArgSortOP failed as could not launch "
+      "cub::DeviceSegmentedRadixSort::SortPairsDescending to sort input, "
+      "temp_storage_bytes:%d status:%s.",
+      temp_storage_bytes, cudaGetErrorString(err));
 }
 
 template <typename T>
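Most of the speedup in the ArgFullSort helper above comes from replacing one thrust::sort_by_key launch per row with a single cub::DeviceSegmentedRadixSort call over all rows, using CUB's usual two-phase pattern: a first call with a null workspace pointer only fills in temp_storage_bytes, and the same call is then repeated with the allocated workspace. Below is a minimal standalone CUDA sketch of that pattern under simplifying assumptions: an explicit device offsets array and raw cudaMalloc buffers stand in for the diff's SegmentOffsetIter transform iterator and Paddle tensors, and error checking is omitted. For descending order the kernel calls SortPairsDescending with identical arguments; float16 keys work the same way once cub::NumericTraits is specialized for paddle::platform::float16 as at the top of this file.

// Standalone sketch of the two-phase cub::DeviceSegmentedRadixSort::SortPairs
// pattern used by ArgFullSort: sort each row (segment) of a 2-D buffer and
// carry per-row indices along as values. Assumes CUB is on the include path
// (bundled with recent CUDA toolkits); error handling omitted for brevity.
#include <cub/cub.cuh>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int num_rows = 2, num_cols = 4, n = num_rows * num_cols;
  std::vector<float> h_keys = {3, 1, 4, 2, 9, 7, 8, 6};
  std::vector<int64_t> h_vals(n);
  for (int r = 0; r < num_rows; ++r)        // same job as FillIndex, done on
    for (int c = 0; c < num_cols; ++c)      // the host here for brevity
      h_vals[r * num_cols + c] = c;
  std::vector<int> h_offsets = {0, 4, 8};   // row r spans [offsets[r], offsets[r+1])

  float *d_keys_in, *d_keys_out;
  int64_t *d_vals_in, *d_vals_out;
  int* d_offsets;
  cudaMalloc(&d_keys_in, n * sizeof(float));
  cudaMalloc(&d_keys_out, n * sizeof(float));
  cudaMalloc(&d_vals_in, n * sizeof(int64_t));
  cudaMalloc(&d_vals_out, n * sizeof(int64_t));
  cudaMalloc(&d_offsets, (num_rows + 1) * sizeof(int));
  cudaMemcpy(d_keys_in, h_keys.data(), n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_vals_in, h_vals.data(), n * sizeof(int64_t), cudaMemcpyHostToDevice);
  cudaMemcpy(d_offsets, h_offsets.data(), (num_rows + 1) * sizeof(int),
             cudaMemcpyHostToDevice);

  // Phase 1: pass nullptr so CUB only reports how much temp storage it needs.
  void* d_temp = nullptr;
  size_t temp_bytes = 0;
  cub::DeviceSegmentedRadixSort::SortPairs(nullptr, temp_bytes, d_keys_in,
                                           d_keys_out, d_vals_in, d_vals_out, n,
                                           num_rows, d_offsets, d_offsets + 1);
  cudaMalloc(&d_temp, temp_bytes);

  // Phase 2: the actual segmented sort; one call covers every row.
  cub::DeviceSegmentedRadixSort::SortPairs(d_temp, temp_bytes, d_keys_in,
                                           d_keys_out, d_vals_in, d_vals_out, n,
                                           num_rows, d_offsets, d_offsets + 1);

  cudaMemcpy(h_keys.data(), d_keys_out, n * sizeof(float), cudaMemcpyDeviceToHost);
  cudaMemcpy(h_vals.data(), d_vals_out, n * sizeof(int64_t), cudaMemcpyDeviceToHost);
  for (int i = 0; i < n; ++i)
    std::printf("%g(%lld) ", h_keys[i], static_cast<long long>(h_vals[i]));
  std::printf("\n");  // expected: 1(1) 2(3) 3(0) 4(2) 6(3) 7(1) 8(2) 9(0)

  cudaFree(d_keys_in); cudaFree(d_keys_out); cudaFree(d_vals_in);
  cudaFree(d_vals_out); cudaFree(d_offsets); cudaFree(d_temp);
  return 0;
}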
@@ -100,51 +168,76 @@ class ArgsortOpCUDAKernel : public framework::OpKernel<T> {
     auto* output = ctx.Output<Tensor>("Out");
     auto* indices = ctx.Output<Tensor>("Indices");
     int axis = ctx.Attr<int>("axis");
+    bool descending = ctx.Attr<bool>("descending");
 
     auto in_dims = input->dims();
     axis = (axis < 0) ? (in_dims.size() + axis) : axis;
 
-    const T* in_data = input->data<T>();
-    T* out_data = output->mutable_data<T>(ctx.GetPlace());
-    int64_t* ids_data = indices->mutable_data<int64_t>(ctx.GetPlace());
-
     int64_t numel = input->numel();
     int64_t groups = numel / in_dims[axis];
 
-    std::vector<int64_t> in_dims_vec = vectorize(in_dims);
-    thrust::device_vector<int64_t> in_dims_dev(in_dims_vec.begin(),
-                                               in_dims_vec.end());
-    int64_t* in_dims_data = thrust::raw_pointer_cast(in_dims_dev.data());
-    // Mediate tensor for sorting data and indices
-    Tensor mediate_output, mediate_indices;
-    T* med_out_data =
-        mediate_output.mutable_data<T>(input->dims(), ctx.GetPlace());
-    int64_t* med_ids_data =
-        mediate_indices.mutable_data<int64_t>(in_dims, ctx.GetPlace());
-    // Target index of each element along the given axis in the mediate tensors
-    Tensor trg_idx_t;
-    int64_t* trg_idx = trg_idx_t.mutable_data<int64_t>(in_dims, ctx.GetPlace());
-
-    auto stream = ctx.cuda_device_context().stream();
-    const int num_threads = PADDLE_CUDA_NUM_THREADS;
-
-    ComputeTargetIdx<<<(numel - 1) / num_threads + 1, num_threads, 0, stream>>>(
-        in_dims_data, in_dims.size(), axis, numel, trg_idx, med_ids_data);
-
-    PermuteInData<<<(numel - 1) / num_threads + 1, num_threads, 0, stream>>>(
-        in_data, trg_idx, numel, med_out_data);
-
-    Sort<<<(groups - 1) / num_threads + 1, num_threads, 0, stream>>>(
-        in_dims[axis], groups, med_out_data, med_ids_data);
-
-    PermuteMediateData<<<(numel - 1) / num_threads + 1, num_threads, 0,
-                         stream>>>(med_out_data, med_ids_data, trg_idx, numel,
-                                   out_data, ids_data);
+    // Special case for full sort, speedup ~190x.
+    if (axis == -1 || axis + 1 == in_dims.size()) {
+      const int64_t input_height = framework::product(
+          framework::slice_ddim(in_dims, 0, in_dims.size() - 1));
+      const int64_t input_width = in_dims[in_dims.size() - 1];
+      const auto& dev_ctx = ctx.cuda_device_context();
+      ArgFullSort<T, int64_t>(dev_ctx, input, output, indices, input_height,
+                              input_width, descending);
+    } else {
+      // if not full sort, do transpose first
+      std::vector<int> trans;
+      for (int i = 0; i < axis; i++) {
+        trans.push_back(i);
+      }
+      trans.push_back(in_dims.size() - 1);
+      for (int i = axis + 1; i < in_dims.size() - 1; i++) {
+        trans.push_back(i);
+      }
+      trans.push_back(axis);
+      framework::DDim trans_dims(in_dims);
+      for (int i = 0; i < trans.size(); i++) {
+        trans_dims[i] = in_dims[trans[i]];
+      }
+
+      Tensor trans_inp;
+      T* trans_inp_data = trans_inp.mutable_data<T>(trans_dims, ctx.GetPlace());
+      int ndims = trans.size();
+      const auto& dev_ctx = ctx.cuda_device_context();
+      // Do transpose
+      TransCompute<platform::CUDADeviceContext, T>(ndims, dev_ctx, *input,
+                                                   &trans_inp, trans);
+
+      const int64_t input_height = framework::product(
+          framework::slice_ddim(trans_dims, 0, trans_dims.size() - 1));
+      const int64_t input_width = trans_dims[trans_dims.size() - 1];
+
+      Tensor tmp_out;
+      tmp_out.mutable_data<T>(trans_dims, ctx.GetPlace());
+      T* out_data = output->mutable_data<T>(ctx.GetPlace());
+
+      Tensor tmp_indices;
+      // temp indices for sorting
+      tmp_indices.mutable_data<int64_t>(trans_dims, ctx.GetPlace());
+      indices->mutable_data<int64_t>(ctx.GetPlace());
+
+      ArgFullSort<T, int64_t>(dev_ctx, &trans_inp, &tmp_out, &tmp_indices,
+                              input_height, input_width, descending);
+
+      TransCompute<platform::CUDADeviceContext, int64_t>(
+          ndims, dev_ctx, tmp_indices, indices, trans);
+      // transpose back
+      TransCompute<platform::CUDADeviceContext, T>(ndims, dev_ctx, tmp_out,
+                                                   output, trans);
+      return;
+    }
   }
 };
 
 }  // namespace operators
 }  // namespace paddle
 
-REGISTER_OP_CUDA_KERNEL(argsort, paddle::operators::ArgsortOpCUDAKernel<float>,
-                        paddle::operators::ArgsortOpCUDAKernel<double>);
+REGISTER_OP_CUDA_KERNEL(
+    argsort, paddle::operators::ArgsortOpCUDAKernel<float>,
+    paddle::operators::ArgsortOpCUDAKernel<double>,
+    paddle::operators::ArgsortOpCUDAKernel<paddle::platform::float16>);
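For any axis other than the last one, the kernel above builds a permutation trans that swaps the sort axis with the last dimension, transposes the input, runs the same ArgFullSort over the rows, and then transposes Out and Indices back. Because the permutation only swaps two axes it is its own inverse, which is why the same trans vector is reused for the transpose back. A small host-side sketch of that permutation logic (illustration only; MakeTrans is a hypothetical helper, not part of the diff):

// Host-side sketch of the axis permutation built in the else-branch above:
// move the sort axis to the last position and keep every other axis in place.
#include <cstdio>
#include <vector>

std::vector<int> MakeTrans(int rank, int axis) {
  std::vector<int> trans;
  for (int i = 0; i < axis; i++) trans.push_back(i);
  trans.push_back(rank - 1);                 // old last axis takes axis's slot
  for (int i = axis + 1; i < rank - 1; i++) trans.push_back(i);
  trans.push_back(axis);                     // sort axis becomes the last one
  return trans;
}

int main() {
  // e.g. a [2, 5, 3, 4] tensor sorted along axis 1 is transposed with
  // trans = {0, 3, 2, 1}, giving shape [2, 4, 3, 5]; rows of length 5 are
  // then sorted by ArgFullSort and the output/indices are transposed back.
  std::vector<int> trans = MakeTrans(/*rank=*/4, /*axis=*/1);
  for (int t : trans) std::printf("%d ", t);  // prints: 0 3 2 1
  std::printf("\n");
  return 0;
}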
