@@ -56,9 +56,10 @@ constexpr uint64_t getDefaultMaxThreadsPerBlock() {
 #endif
 }

-#if 0
+#ifdef USE_ROCM
+#define SKIP_SORTED_INDICES 32
 template <typename scalar_t, int SZ>
-__global__ void indexing_backward_kernel(
+__global__ void indexing_backward_kernel_many_indices(
   const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
   int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim, bool accumulate) {
   using opmath_t = at::opmath_type<scalar_t>;
@@ -141,10 +142,7 @@ __global__ void indexing_backward_kernel(
     }
   }
 }
-#endif

-#ifdef USE_ROCM
-#define SKIP_SORTED_INDICES 32
 template <typename scalar_t>
 __global__ void indexing_backward_kernel_stride_1(
   const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
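Net effect of the two hunks above: the kernel that was previously compiled out under #if 0 is revived as a ROCm-only kernel, renamed to indexing_backward_kernel_many_indices, and the USE_ROCM guard plus the SKIP_SORTED_INDICES define move up so that it shares one guarded region with indexing_backward_kernel_stride_1. A rough sketch of the resulting layout (declarations elided, not the literal file contents):

#ifdef USE_ROCM
#define SKIP_SORTED_INDICES 32

// Revived under USE_ROCM, renamed from the old indexing_backward_kernel
template <typename scalar_t, int SZ>
__global__ void indexing_backward_kernel_many_indices(/* same parameter list as before */);

// Unchanged, but no longer behind its own #ifdef USE_ROCM
template <typename scalar_t>
__global__ void indexing_backward_kernel_stride_1(/* ... */);

// ... the matching #endif sits further down, outside these hunks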
@@ -784,6 +782,67 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
         kBool,
         kBFloat16);
     } else {
+#ifdef USE_ROCM
+      if (num_indices < 200000) {
+        AT_DISPATCH_V2(
+          expandedValue.scalar_type(),
+          "indexing_backward",
+          AT_WRAP([&] {
+            indexing_backward_kernel<scalar_t, UNROLL><<<grid, block, 0, stream>>>(
+              sorted_indices.const_data_ptr<int64_t>(),
+              orig_indices.const_data_ptr<int64_t>(),
+              expandedValue.const_data_ptr<scalar_t>(),
+              src_.mutable_data_ptr<scalar_t>(),
+              num_indices,
+              sliceSize,
+              strideBefore,
+              nElemBefore,
+              accumulate);
+            C10_CUDA_KERNEL_LAUNCH_CHECK();
+          }),
+          AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
+          // AT_EXPAND(AT_FLOAT8_TYPES),
+          // TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
+          // should not be supported here, then reenable AT_FLOAT8_DTYPES
+          kFloat8_e4m3fn,
+          kFloat8_e5m2,
+          kFloat8_e4m3fnuz,
+          kFloat8_e5m2fnuz,
+          kComplexHalf,
+          kHalf,
+          kBool,
+          kBFloat16);
+      } else {
+        AT_DISPATCH_V2(
+          expandedValue.scalar_type(),
+          "indexing_backward_many_indices",
+          AT_WRAP([&] {
+            indexing_backward_kernel_many_indices<scalar_t, UNROLL><<<new_grid, block, smem_dups_size, stream>>>(
+              sorted_indices.const_data_ptr<int64_t>(),
+              orig_indices.const_data_ptr<int64_t>(),
+              expandedValue.const_data_ptr<scalar_t>(),
+              src_.mutable_data_ptr<scalar_t>(),
+              num_indices,
+              sliceSize,
+              strideBefore,
+              nElemBefore,
+              accumulate);
+            C10_CUDA_KERNEL_LAUNCH_CHECK();
+          }),
+          AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
+          // AT_EXPAND(AT_FLOAT8_TYPES),
+          // TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
+          // should not be supported here, then reenable AT_FLOAT8_DTYPES
+          kFloat8_e4m3fn,
+          kFloat8_e5m2,
+          kFloat8_e4m3fnuz,
+          kFloat8_e5m2fnuz,
+          kComplexHalf,
+          kHalf,
+          kBool,
+          kBFloat16);
+      }
+#else
       AT_DISPATCH_V2(
         expandedValue.scalar_type(),
         "indexing_backward",
@@ -812,6 +871,7 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
         kHalf,
         kBool,
         kBFloat16);
+#endif
     }
   }

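Taken together, the new ROCm-only branch picks a kernel by index count: below 200000 sorted indices it keeps launching the stock indexing_backward_kernel on the original grid, and at or above that it switches to indexing_backward_kernel_many_indices on new_grid with smem_dups_size bytes of dynamic shared memory (new_grid and smem_dups_size are defined elsewhere in the function, outside these hunks). A minimal sketch of just that selection, with the AT_DISPATCH_V2 / AT_WRAP boilerplate stripped:

// Sketch only; the real code wraps each launch in AT_DISPATCH_V2 over the value dtype.
if (num_indices < 200000) {
  // Few indices: the existing kernel and launch configuration are kept.
  indexing_backward_kernel<scalar_t, UNROLL><<<grid, block, 0, stream>>>(
      sorted_indices.const_data_ptr<int64_t>(), orig_indices.const_data_ptr<int64_t>(),
      expandedValue.const_data_ptr<scalar_t>(), src_.mutable_data_ptr<scalar_t>(),
      num_indices, sliceSize, strideBefore, nElemBefore, accumulate);
} else {
  // Many indices: ROCm-specific kernel with a larger grid and dynamic shared
  // memory, presumably for staging runs of duplicate indices (an assumption;
  // the kernel body is outside this diff).
  indexing_backward_kernel_many_indices<scalar_t, UNROLL><<<new_grid, block, smem_dups_size, stream>>>(
      sorted_indices.const_data_ptr<int64_t>(), orig_indices.const_data_ptr<int64_t>(),
      expandedValue.const_data_ptr<scalar_t>(), src_.mutable_data_ptr<scalar_t>(),
      num_indices, sliceSize, strideBefore, nElemBefore, accumulate);
}
C10_CUDA_KERNEL_LAUNCH_CHECK();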