@@ -56,9 +56,10 @@ constexpr uint64_t getDefaultMaxThreadsPerBlock() {
5656#endif
5757}
5858
59- #if 0
59+ #ifdef USE_ROCM
60+ #define SKIP_SORTED_INDICES 32
6061template <typename scalar_t , int SZ>
61- __global__ void indexing_backward_kernel (
62+ __global__ void indexing_backward_kernel_many_indices (
6263 const int64_t * sorted_indices, const int64_t * indices, const scalar_t * grad_output, scalar_t * grad_weight,
6364 int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim, bool accumulate) {
6465 using opmath_t = at::opmath_type<scalar_t >;
@@ -141,10 +142,7 @@ __global__ void indexing_backward_kernel(
141142 }
142143 }
143144}
144- #endif
145145
146- #ifdef USE_ROCM
147- #define SKIP_SORTED_INDICES 32
148146template <typename scalar_t >
149147__global__ void indexing_backward_kernel_stride_1 (
150148 const int64_t * sorted_indices, const int64_t * indices, const scalar_t * grad_output, scalar_t * grad_weight,
@@ -784,6 +782,38 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
784782 kBool ,
785783 kBFloat16 );
786784 } else {
785+ #ifdef USE_ROCM
786+ if (num_indices >= 200000 )
787+ AT_DISPATCH_V2 (
788+ expandedValue.scalar_type (),
789+ " indexing_backward_many_indices" ,
790+ AT_WRAP ([&] {
791+ indexing_backward_kernel_many_indices<scalar_t , UNROLL><<<new_grid, block, smem_dups_size, stream>>> (
792+ sorted_indices.const_data_ptr <int64_t >(),
793+ orig_indices.const_data_ptr <int64_t >(),
794+ expandedValue.const_data_ptr <scalar_t >(),
795+ src_.mutable_data_ptr <scalar_t >(),
796+ num_indices,
797+ sliceSize,
798+ strideBefore,
799+ nElemBefore,
800+ accumulate);
801+ C10_CUDA_KERNEL_LAUNCH_CHECK ();
802+ }),
803+ AT_EXPAND (AT_ALL_TYPES_AND_COMPLEX),
804+ // AT_EXPAND(AT_FLOAT8_TYPES),
805+ // TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
806+ // should not be supported here, then reenable AT_FLOAT8_TYPES
807+ kFloat8_e4m3fn ,
808+ kFloat8_e5m2 ,
809+ kFloat8_e4m3fnuz ,
810+ kFloat8_e5m2fnuz ,
811+ kComplexHalf ,
812+ kHalf ,
813+ kBool ,
814+ kBFloat16 );
815+ else
816+ #endif
787817 AT_DISPATCH_V2 (
788818 expandedValue.scalar_type (),
789819 " indexing_backward" ,
0 commit comments