Skip to content

Commit 1829812

Browse files
authored
Fix sort kernel launch bug when nrows exceeds the gridDim.y limit (65535) (#3050)
1 parent 86bcf1e commit 1829812

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

candle-core/src/sort.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ mod cuda {
9494
let nrows = elem_count / ncols;
9595
let ncols_pad = next_power_of_2(ncols);
9696
let cfg = LaunchConfig {
97-
grid_dim: (1, nrows as u32, 1),
97+
grid_dim: (nrows as u32, 1, 1),
9898
block_dim: (ncols_pad as u32, 1, 1),
9999
shared_mem_bytes: (ncols_pad * std::mem::size_of::<u32>()) as u32,
100100
};

candle-kernels/src/sort.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ template<int order, typename T>
1515
static __device__ void k_argsort(const T * x, uint32_t * dst, const int ncols, int ncols_pad) {
1616
// bitonic sort
1717
int col = threadIdx.x;
18-
int row = blockIdx.y;
18+
int row = blockIdx.x;
1919

2020
if (col >= ncols_pad) {
2121
return;

0 commit comments

Comments
 (0)