@@ -55,8 +55,7 @@ constexpr uint64_t getDefaultMaxThreadsPerBlock() {
 #endif
 }
 
-#ifdef USE_ROCM
-#define SKIP_SORTED_INDICES 32
+#if 0
 template <typename scalar_t, int SZ>
 __global__ void indexing_backward_kernel(
   const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
@@ -141,7 +140,10 @@ __global__ void indexing_backward_kernel(
     }
   }
 }
+#endif
 
+#ifdef USE_ROCM
+#define SKIP_SORTED_INDICES 32
 template <typename scalar_t>
 __global__ void indexing_backward_kernel_stride_1(
   const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
@@ -253,7 +255,8 @@ __global__ void indexing_backward_kernel_stride_1(
     }
   }
 }
-#else
+#endif
+
 template <typename scalar_t, int SZ>
 __global__ void indexing_backward_kernel(
   const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
@@ -332,6 +335,7 @@ __global__ void indexing_backward_kernel(
   }
 }
 
+#ifndef USE_ROCM
 template <typename scalar_t>
 __global__ void indexing_backward_kernel_stride_1(
   const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
@@ -790,7 +794,7 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
       expandedValue.scalar_type(),
       "indexing_backward",
       AT_WRAP([&] {
-        indexing_backward_kernel<scalar_t, UNROLL><<<KERNEL_GRID, block, KERNEL_SMEM, stream>>>(
+        indexing_backward_kernel<scalar_t, UNROLL><<<grid, block, 0, stream>>>(
           sorted_indices.const_data_ptr<int64_t>(),
           orig_indices.const_data_ptr<int64_t>(),
           expandedValue.const_data_ptr<scalar_t>(),