Commit fe1f5d7

[ROCm] Fix non-stride-one backwards indexing performance (#2693)
cherry-pick of pytorch#164409
1 parent cbd27ae commit fe1f5d7

File tree

1 file changed: +35 −5 lines


aten/src/ATen/native/cuda/Indexing.cu

Lines changed: 35 additions & 5 deletions
@@ -56,9 +56,10 @@ constexpr uint64_t getDefaultMaxThreadsPerBlock() {
 #endif
 }
 
-#if 0
+#ifdef USE_ROCM
+#define SKIP_SORTED_INDICES 32
 template <typename scalar_t, int SZ>
-__global__ void indexing_backward_kernel(
+__global__ void indexing_backward_kernel_many_indices(
   const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
   int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim, bool accumulate) {
   using opmath_t = at::opmath_type<scalar_t>;
@@ -141,10 +142,7 @@ __global__ void indexing_backward_kernel(
     }
   }
 }
-#endif
 
-#ifdef USE_ROCM
-#define SKIP_SORTED_INDICES 32
 template <typename scalar_t>
 __global__ void indexing_backward_kernel_stride_1(
   const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
@@ -784,6 +782,38 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
         kBool,
         kBFloat16);
     } else {
+#ifdef USE_ROCM
+      if (num_indices >= 200000)
+        AT_DISPATCH_V2(
+          expandedValue.scalar_type(),
+          "indexing_backward_many_indices",
+          AT_WRAP([&] {
+            indexing_backward_kernel_many_indices<scalar_t, UNROLL><<<new_grid, block, smem_dups_size, stream>>>(
+              sorted_indices.const_data_ptr<int64_t>(),
+              orig_indices.const_data_ptr<int64_t>(),
+              expandedValue.const_data_ptr<scalar_t>(),
+              src_.mutable_data_ptr<scalar_t>(),
+              num_indices,
+              sliceSize,
+              strideBefore,
+              nElemBefore,
+              accumulate);
+            C10_CUDA_KERNEL_LAUNCH_CHECK();
+          }),
+          AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
+          // AT_EXPAND(AT_FLOAT8_TYPES),
+          // TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
+          // should not be supported here, then reenable AT_FLOAT8_DTYPES
+          kFloat8_e4m3fn,
+          kFloat8_e5m2,
+          kFloat8_e4m3fnuz,
+          kFloat8_e5m2fnuz,
+          kComplexHalf,
+          kHalf,
+          kBool,
+          kBFloat16);
+      else
+#endif
       AT_DISPATCH_V2(
         expandedValue.scalar_type(),
         "indexing_backward",
