@@ -55,9 +55,10 @@ constexpr uint64_t getDefaultMaxThreadsPerBlock() {
5555#endif
5656}
5757
58- #if 0
58+ #ifdef USE_ROCM
59+ #define SKIP_SORTED_INDICES 32
5960template <typename scalar_t , int SZ>
60- __global__ void indexing_backward_kernel (
61+ __global__ void indexing_backward_kernel_many_indices (
6162 const int64_t * sorted_indices, const int64_t * indices, const scalar_t * grad_output, scalar_t * grad_weight,
6263 int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim, bool accumulate) {
6364 using opmath_t = at::opmath_type<scalar_t >;
@@ -140,10 +141,7 @@ __global__ void indexing_backward_kernel(
140141 }
141142 }
142143}
143- #endif
144144
145- #ifdef USE_ROCM
146- #define SKIP_SORTED_INDICES 32
147145template <typename scalar_t >
148146__global__ void indexing_backward_kernel_stride_1 (
149147 const int64_t * sorted_indices, const int64_t * indices, const scalar_t * grad_output, scalar_t * grad_weight,
@@ -790,6 +788,38 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
790788 kBool ,
791789 kBFloat16 );
792790 } else {
791+ #ifdef USE_ROCM
792+ if (num_indices >= 200000 )
793+ AT_DISPATCH_V2 (
794+ expandedValue.scalar_type (),
795+ " indexing_backward_many_indices" ,
796+ AT_WRAP ([&] {
797+ indexing_backward_kernel_many_indices<scalar_t , UNROLL><<<new_grid, block, smem_dups_size, stream>>> (
798+ sorted_indices.const_data_ptr <int64_t >(),
799+ orig_indices.const_data_ptr <int64_t >(),
800+ expandedValue.const_data_ptr <scalar_t >(),
801+ src_.mutable_data_ptr <scalar_t >(),
802+ num_indices,
803+ sliceSize,
804+ strideBefore,
805+ nElemBefore,
806+ accumulate);
807+ C10_CUDA_KERNEL_LAUNCH_CHECK ();
808+ }),
809+ AT_EXPAND (AT_ALL_TYPES_AND_COMPLEX),
810+ // AT_EXPAND(AT_FLOAT8_TYPES),
811+ // TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
812+ // should not be supported here, then reenable AT_FLOAT8_TYPES
813+ kFloat8_e4m3fn ,
814+ kFloat8_e5m2 ,
815+ kFloat8_e4m3fnuz ,
816+ kFloat8_e5m2fnuz ,
817+ kComplexHalf ,
818+ kHalf ,
819+ kBool ,
820+ kBFloat16 );
821+ else
822+ #endif
793823 AT_DISPATCH_V2 (
794824 expandedValue.scalar_type (),
795825 " indexing_backward" ,
0 commit comments