@@ -56,8 +56,7 @@ constexpr uint64_t getDefaultMaxThreadsPerBlock() {
 #endif
 }

-#ifdef USE_ROCM
-#define SKIP_SORTED_INDICES 32
+#if 0
 template <typename scalar_t, int SZ>
 __global__ void indexing_backward_kernel(
   const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
@@ -142,7 +141,10 @@ __global__ void indexing_backward_kernel(
     }
   }
 }
+#endif

+#ifdef USE_ROCM
+#define SKIP_SORTED_INDICES 32
 template <typename scalar_t>
 __global__ void indexing_backward_kernel_stride_1(
   const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
@@ -254,7 +256,8 @@ __global__ void indexing_backward_kernel_stride_1(
     }
   }
 }
-#else
+#endif
+
 template <typename scalar_t, int SZ>
 __global__ void indexing_backward_kernel(
   const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
@@ -333,6 +336,7 @@ __global__ void indexing_backward_kernel(
   }
 }

+#ifndef USE_ROCM
 template <typename scalar_t>
 __global__ void indexing_backward_kernel_stride_1(
   const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
@@ -784,7 +788,7 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
       expandedValue.scalar_type(),
       "indexing_backward",
       AT_WRAP([&] {
-        indexing_backward_kernel<scalar_t, UNROLL><<<KERNEL_GRID, block, KERNEL_SMEM, stream>>>(
+        indexing_backward_kernel<scalar_t, UNROLL><<<grid, block, 0, stream>>>(
           sorted_indices.const_data_ptr<int64_t>(),
           orig_indices.const_data_ptr<int64_t>(),
           expandedValue.const_data_ptr<scalar_t>(),
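
Taken together, these hunks restructure the preprocessor guards: the old ROCm-only indexing_backward_kernel is compiled out with #if 0, the ROCm stride-1 kernel keeps its #ifdef USE_ROCM / SKIP_SORTED_INDICES guard, the generic indexing_backward_kernel becomes unconditional, and the non-ROCm stride-1 kernel moves under #ifndef USE_ROCM. A rough sketch of the resulting guard layout, reconstructed only from the hunks above (parameter lists abbreviated, not the full file):

// Approximate layout after this change, inferred from the diff hunks above;
// signatures are abbreviated for brevity.

#if 0  // former ROCm-tuned kernel, now dead code
template <typename scalar_t, int SZ>
__global__ void indexing_backward_kernel(/* sorted_indices, indices, grad_output, grad_weight, ... */);
#endif

#ifdef USE_ROCM
#define SKIP_SORTED_INDICES 32
template <typename scalar_t>
__global__ void indexing_backward_kernel_stride_1(/* ... */);  // ROCm-specific stride-1 path
#endif

// Shared kernel, now built for both CUDA and ROCm.
template <typename scalar_t, int SZ>
__global__ void indexing_backward_kernel(/* ... */);

#ifndef USE_ROCM
template <typename scalar_t>
__global__ void indexing_backward_kernel_stride_1(/* ... */);  // CUDA-only stride-1 path
#endif

The launch site in index_put_with_sort_kernel then drops the ROCm-specific KERNEL_GRID/KERNEL_SMEM macros in favor of the plain grid and a zero dynamic-shared-memory argument, matching the now-shared kernel.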