From 4793a4d29c80989d7d5481ae40f19d2fbb88d75b Mon Sep 17 00:00:00 2001 From: Jerry Mannil <65309407+jerrymannil@users.noreply.github.com> Date: Tue, 23 Sep 2025 14:06:29 -0700 Subject: [PATCH] [ROCm] Fix indexing_backward_kernel perf cherry-pick of 8d42697 --- aten/src/ATen/native/cuda/Indexing.cu | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/aten/src/ATen/native/cuda/Indexing.cu b/aten/src/ATen/native/cuda/Indexing.cu index 02feb55cb69d6..e49fffc2effcd 100644 --- a/aten/src/ATen/native/cuda/Indexing.cu +++ b/aten/src/ATen/native/cuda/Indexing.cu @@ -56,8 +56,7 @@ constexpr uint64_t getDefaultMaxThreadsPerBlock() { #endif } -#ifdef USE_ROCM -#define SKIP_SORTED_INDICES 32 +#if 0 template __global__ void indexing_backward_kernel( const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight, @@ -142,7 +141,10 @@ __global__ void indexing_backward_kernel( } } } +#endif +#ifdef USE_ROCM +#define SKIP_SORTED_INDICES 32 template __global__ void indexing_backward_kernel_stride_1( const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight, @@ -254,7 +256,8 @@ __global__ void indexing_backward_kernel_stride_1( } } } -#else +#endif + template __global__ void indexing_backward_kernel( const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight, @@ -333,6 +336,7 @@ __global__ void indexing_backward_kernel( } } +#ifndef USE_ROCM template __global__ void indexing_backward_kernel_stride_1( const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight, @@ -784,7 +788,7 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<<>>( + indexing_backward_kernel<<>>( sorted_indices.const_data_ptr(), orig_indices.const_data_ptr(), expandedValue.const_data_ptr(),