Commit 4ba0bd5

doru1004 (AMD) authored and committed
[ROCm] Fix non-stride-one backwards indexing performance (#2693)
cherry-pick of pytorch#164409
1 parent c2114ee · commit 4ba0bd5

File tree

1 file changed (+35, -5 lines)


aten/src/ATen/native/cuda/Indexing.cu

Lines changed: 35 additions & 5 deletions
@@ -55,9 +55,10 @@ constexpr uint64_t getDefaultMaxThreadsPerBlock() {
 #endif
 }
 
-#if 0
+#ifdef USE_ROCM
+#define SKIP_SORTED_INDICES 32
 template <typename scalar_t, int SZ>
-__global__ void indexing_backward_kernel(
+__global__ void indexing_backward_kernel_many_indices(
   const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
   int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim, bool accumulate) {
   using opmath_t = at::opmath_type<scalar_t>;
@@ -140,10 +141,7 @@ __global__ void indexing_backward_kernel(
     }
   }
 }
-#endif
 
-#ifdef USE_ROCM
-#define SKIP_SORTED_INDICES 32
 template <typename scalar_t>
 __global__ void indexing_backward_kernel_stride_1(
   const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
@@ -790,6 +788,38 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
         kBool,
         kBFloat16);
     } else {
+#ifdef USE_ROCM
+      if (num_indices >= 200000)
+        AT_DISPATCH_V2(
+          expandedValue.scalar_type(),
+          "indexing_backward_many_indices",
+          AT_WRAP([&] {
+            indexing_backward_kernel_many_indices<scalar_t, UNROLL><<<new_grid, block, smem_dups_size, stream>>>(
+              sorted_indices.const_data_ptr<int64_t>(),
+              orig_indices.const_data_ptr<int64_t>(),
+              expandedValue.const_data_ptr<scalar_t>(),
+              src_.mutable_data_ptr<scalar_t>(),
+              num_indices,
+              sliceSize,
+              strideBefore,
+              nElemBefore,
+              accumulate);
+            C10_CUDA_KERNEL_LAUNCH_CHECK();
+          }),
+          AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
+          // AT_EXPAND(AT_FLOAT8_TYPES),
+          // TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
+          // should not be supported here, then reenable AT_FLOAT8_DTYPES
+          kFloat8_e4m3fn,
+          kFloat8_e5m2,
+          kFloat8_e4m3fnuz,
+          kFloat8_e5m2fnuz,
+          kComplexHalf,
+          kHalf,
+          kBool,
+          kBFloat16);
+      else
+#endif
       AT_DISPATCH_V2(
         expandedValue.scalar_type(),
         "indexing_backward",

0 commit comments