Skip to content

Commit 71ef8e9

Browse files
committed
Fix non-stride-one backwards indexing performance
1 parent 0b82d9a commit 71ef8e9

File tree

1 file changed

+65
-5
lines changed

1 file changed

+65
-5
lines changed

aten/src/ATen/native/cuda/Indexing.cu

Lines changed: 65 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,10 @@ constexpr uint64_t getDefaultMaxThreadsPerBlock() {
5656
#endif
5757
}
5858

59-
#if 0
59+
#ifdef USE_ROCM
60+
#define SKIP_SORTED_INDICES 32
6061
template <typename scalar_t, int SZ>
61-
__global__ void indexing_backward_kernel(
62+
__global__ void indexing_backward_kernel_many_indices(
6263
const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
6364
int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim, bool accumulate) {
6465
using opmath_t = at::opmath_type<scalar_t>;
@@ -141,10 +142,7 @@ __global__ void indexing_backward_kernel(
141142
}
142143
}
143144
}
144-
#endif
145145

146-
#ifdef USE_ROCM
147-
#define SKIP_SORTED_INDICES 32
148146
template <typename scalar_t>
149147
__global__ void indexing_backward_kernel_stride_1(
150148
const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
@@ -784,6 +782,67 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
784782
kBool,
785783
kBFloat16);
786784
} else {
785+
#ifdef USE_ROCM
786+
if (num_indices < 200000) {
787+
AT_DISPATCH_V2(
788+
expandedValue.scalar_type(),
789+
"indexing_backward",
790+
AT_WRAP([&] {
791+
indexing_backward_kernel<scalar_t, UNROLL><<<grid, block, 0, stream>>>(
792+
sorted_indices.const_data_ptr<int64_t>(),
793+
orig_indices.const_data_ptr<int64_t>(),
794+
expandedValue.const_data_ptr<scalar_t>(),
795+
src_.mutable_data_ptr<scalar_t>(),
796+
num_indices,
797+
sliceSize,
798+
strideBefore,
799+
nElemBefore,
800+
accumulate);
801+
C10_CUDA_KERNEL_LAUNCH_CHECK();
802+
}),
803+
AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
804+
// AT_EXPAND(AT_FLOAT8_TYPES),
805+
// TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
806+
// should not be supported here, then reenable AT_FLOAT8_DTYPES
807+
kFloat8_e4m3fn,
808+
kFloat8_e5m2,
809+
kFloat8_e4m3fnuz,
810+
kFloat8_e5m2fnuz,
811+
kComplexHalf,
812+
kHalf,
813+
kBool,
814+
kBFloat16);
815+
} else {
816+
AT_DISPATCH_V2(
817+
expandedValue.scalar_type(),
818+
"indexing_backward_many_indices",
819+
AT_WRAP([&] {
820+
indexing_backward_kernel_many_indices<scalar_t, UNROLL><<<new_grid, block, smem_dups_size, stream>>>(
821+
sorted_indices.const_data_ptr<int64_t>(),
822+
orig_indices.const_data_ptr<int64_t>(),
823+
expandedValue.const_data_ptr<scalar_t>(),
824+
src_.mutable_data_ptr<scalar_t>(),
825+
num_indices,
826+
sliceSize,
827+
strideBefore,
828+
nElemBefore,
829+
accumulate);
830+
C10_CUDA_KERNEL_LAUNCH_CHECK();
831+
}),
832+
AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
833+
// AT_EXPAND(AT_FLOAT8_TYPES),
834+
// TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
835+
// should not be supported here, then reenable AT_FLOAT8_DTYPES
836+
kFloat8_e4m3fn,
837+
kFloat8_e5m2,
838+
kFloat8_e4m3fnuz,
839+
kFloat8_e5m2fnuz,
840+
kComplexHalf,
841+
kHalf,
842+
kBool,
843+
kBFloat16);
844+
}
845+
#else
787846
AT_DISPATCH_V2(
788847
expandedValue.scalar_type(),
789848
"indexing_backward",
@@ -812,6 +871,7 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
812871
kHalf,
813872
kBool,
814873
kBFloat16);
874+
#endif
815875
}
816876
}
817877

0 commit comments

Comments
 (0)