Skip to content

Commit 2a07dfa

Browse files
authored
[ROCm] Fix indexing_backward_kernel perf (#2673)
cherry-pick of 8d42697
1 parent 7ea3967 commit 2a07dfa

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

aten/src/ATen/native/cuda/Indexing.cu

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,7 @@ constexpr uint64_t getDefaultMaxThreadsPerBlock() {
5656
#endif
5757
}
5858

59-
#ifdef USE_ROCM
60-
#define SKIP_SORTED_INDICES 32
59+
#if 0
6160
template <typename scalar_t, int SZ>
6261
__global__ void indexing_backward_kernel(
6362
const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
@@ -142,7 +141,10 @@ __global__ void indexing_backward_kernel(
142141
}
143142
}
144143
}
144+
#endif
145145

146+
#ifdef USE_ROCM
147+
#define SKIP_SORTED_INDICES 32
146148
template <typename scalar_t>
147149
__global__ void indexing_backward_kernel_stride_1(
148150
const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
@@ -254,7 +256,8 @@ __global__ void indexing_backward_kernel_stride_1(
254256
}
255257
}
256258
}
257-
#else
259+
#endif
260+
258261
template <typename scalar_t, int SZ>
259262
__global__ void indexing_backward_kernel(
260263
const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
@@ -333,6 +336,7 @@ __global__ void indexing_backward_kernel(
333336
}
334337
}
335338

339+
#ifndef USE_ROCM
336340
template <typename scalar_t>
337341
__global__ void indexing_backward_kernel_stride_1(
338342
const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
@@ -784,7 +788,7 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
784788
expandedValue.scalar_type(),
785789
"indexing_backward",
786790
AT_WRAP([&] {
787-
indexing_backward_kernel<scalar_t, UNROLL><<<KERNEL_GRID, block, KERNEL_SMEM, stream>>>(
791+
indexing_backward_kernel<scalar_t, UNROLL><<<grid, block, 0, stream>>>(
788792
sorted_indices.const_data_ptr<int64_t>(),
789793
orig_indices.const_data_ptr<int64_t>(),
790794
expandedValue.const_data_ptr<scalar_t>(),

0 commit comments

Comments
 (0)