Skip to content

Commit 975f61d

Browse files
authored
[ROCm] Adjust grid size for non-unit stride backwards indexing (#2714)
cherry-pick of pytorch@01a2812
1 parent 018e50b commit 975f61d

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

aten/src/ATen/native/cuda/Indexing.cu

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -716,6 +716,9 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
716716
dim3 block(warp_size, indices_per_block);
717717

718718
#ifdef USE_ROCM
719+
dim3 new_grid_many_indices(ceil_div(num_indices, (int64_t) (indices_per_block * warp_size)),
720+
grid.y == 1 ? std::min<int>(at::cuda::getCurrentDeviceProperties()->maxGridSize[1], ceil_div(sliceSize, (int64_t) (warp_size))) : grid.y,
721+
grid.z);
719722
dim3 new_grid(ceil_div(num_indices, (int64_t) (indices_per_block * warp_size)), grid.y, grid.z);
720723
size_t smem_dups_size = indices_per_block * warp_size * sizeof(int64_t);
721724
#define KERNEL_GRID new_grid
@@ -794,7 +797,7 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
794797
expandedValue.scalar_type(),
795798
"indexing_backward_many_indices",
796799
AT_WRAP([&] {
797-
indexing_backward_kernel_many_indices<scalar_t, UNROLL><<<new_grid, block, smem_dups_size, stream>>>(
800+
indexing_backward_kernel_many_indices<scalar_t, UNROLL><<<new_grid_many_indices, block, smem_dups_size, stream>>>(
798801
sorted_indices.const_data_ptr<int64_t>(),
799802
orig_indices.const_data_ptr<int64_t>(),
800803
expandedValue.const_data_ptr<scalar_t>(),

0 commit comments

Comments
 (0)