Skip to content

Commit 8d42697

Browse files
authored
[ROCm] Fix indexing_backward_kernel perf (#2650)
* Revert of a1cb67b. Fixes #SWDEV-552103.
1 parent: c02c48a; commit: 8d42697

File tree

1 file changed: +8 -4 lines changed

aten/src/ATen/native/cuda/Indexing.cu

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -55,8 +55,7 @@ constexpr uint64_t getDefaultMaxThreadsPerBlock() {
5555
#endif
5656
}
5757

58-
#ifdef USE_ROCM
59-
#define SKIP_SORTED_INDICES 32
58+
#if 0
6059
template <typename scalar_t, int SZ>
6160
__global__ void indexing_backward_kernel(
6261
const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
@@ -141,7 +140,10 @@ __global__ void indexing_backward_kernel(
141140
}
142141
}
143142
}
143+
#endif
144144

145+
#ifdef USE_ROCM
146+
#define SKIP_SORTED_INDICES 32
145147
template <typename scalar_t>
146148
__global__ void indexing_backward_kernel_stride_1(
147149
const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
@@ -253,7 +255,8 @@ __global__ void indexing_backward_kernel_stride_1(
253255
}
254256
}
255257
}
256-
#else
258+
#endif
259+
257260
template <typename scalar_t, int SZ>
258261
__global__ void indexing_backward_kernel(
259262
const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
@@ -332,6 +335,7 @@ __global__ void indexing_backward_kernel(
332335
}
333336
}
334337

338+
#ifndef USE_ROCM
335339
template <typename scalar_t>
336340
__global__ void indexing_backward_kernel_stride_1(
337341
const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
@@ -790,7 +794,7 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
790794
expandedValue.scalar_type(),
791795
"indexing_backward",
792796
AT_WRAP([&] {
793-
indexing_backward_kernel<scalar_t, UNROLL><<<KERNEL_GRID, block, KERNEL_SMEM, stream>>>(
797+
indexing_backward_kernel<scalar_t, UNROLL><<<grid, block, 0, stream>>>(
794798
sorted_indices.const_data_ptr<int64_t>(),
795799
orig_indices.const_data_ptr<int64_t>(),
796800
expandedValue.const_data_ptr<scalar_t>(),

0 commit comments

Comments
 (0)