1 file changed: +14 -1 lines changed
Original file line number | Diff line number | Diff line change
@@ -198,6 +198,13 @@ void ArgFullSort(const phi::GPUContext& dev_ctx,
198
198
const int64_t num_rows,
199
199
const int64_t num_cols,
200
200
const bool descending) {
201
+ PADDLE_ENFORCE_LE (num_cols,
202
+ std::numeric_limits<int >::max (),
203
+ ::common::errors::PreconditionNotMet (
204
+ " The dimension being sorted should be less than "
205
+ " 2^31, but got %lld. Please check the input tensor. " ,
206
+ num_cols));
207
+
201
208
auto cu_stream = dev_ctx.stream ();
202
209
auto ComputeBlockSize = [](IndType col) {
203
210
if (col > 512 )
@@ -228,8 +235,14 @@ void ArgFullSort(const phi::GPUContext& dev_ctx,
228
235
const int64_t total_elements = num_cols * num_rows;
229
236
const int64_t segment_size = num_cols;
230
237
const int64_t element_per_call = std::min (max_elements, total_elements);
238
+
239
+ // make sure adjusted_elements_per_call >= segment_size, so that batch_size below covers at least one full segment
240
+ const int64_t adjusted_elements_per_call =
241
+ std::max (max_elements, segment_size);
242
+
231
243
// make sure batch size is the multiple of segment_size
232
- const int64_t batch_size = (element_per_call / segment_size) * segment_size;
244
+ const int64_t batch_size =
245
+ (adjusted_elements_per_call / segment_size) * segment_size;
233
246
int64_t offset = 0 ;
234
247
DenseTensor input_indices;
235
248
You can’t perform that action at this time.
0 commit comments