@@ -716,6 +716,9 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
716716 dim3 block (warp_size, indices_per_block);
717717
718718#ifdef USE_ROCM
719+ dim3 new_grid_many_indices (ceil_div (num_indices, (int64_t ) (indices_per_block * warp_size)),
720+ grid.y == 1 ? std::min<int >(at::cuda::getCurrentDeviceProperties ()->maxGridSize [1 ], ceil_div (sliceSize, (int64_t ) (warp_size))) : grid.y ,
721+ grid.z );
719722 dim3 new_grid (ceil_div (num_indices, (int64_t ) (indices_per_block * warp_size)), grid.y , grid.z );
720723 size_t smem_dups_size = indices_per_block * warp_size * sizeof (int64_t );
721724#define KERNEL_GRID new_grid
@@ -794,7 +797,7 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
794797 expandedValue.scalar_type (),
795798 " indexing_backward_many_indices" ,
796799 AT_WRAP ([&] {
797- indexing_backward_kernel_many_indices<scalar_t , UNROLL><<<new_grid , block, smem_dups_size, stream>>> (
800+ indexing_backward_kernel_many_indices<scalar_t , UNROLL><<<new_grid_many_indices , block, smem_dups_size, stream>>> (
798801 sorted_indices.const_data_ptr <int64_t >(),
799802 orig_indices.const_data_ptr <int64_t >(),
800803 expandedValue.const_data_ptr <scalar_t >(),
0 commit comments