@@ -56,9 +56,10 @@ constexpr uint64_t getDefaultMaxThreadsPerBlock() {
5656#endif
5757}
5858
59- #if 0
59+ #ifdef USE_ROCM
60+ #define SKIP_SORTED_INDICES 32
6061template <typename scalar_t , int SZ>
61- __global__ void indexing_backward_kernel (
62+ __global__ void indexing_backward_kernel_many_indices (
6263 const int64_t * sorted_indices, const int64_t * indices, const scalar_t * grad_output, scalar_t * grad_weight,
6364 int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim, bool accumulate) {
6465 using opmath_t = at::opmath_type<scalar_t >;
@@ -141,10 +142,7 @@ __global__ void indexing_backward_kernel(
141142 }
142143 }
143144}
144- #endif
145145
146- #ifdef USE_ROCM
147- #define SKIP_SORTED_INDICES 32
148146template <typename scalar_t >
149147__global__ void indexing_backward_kernel_stride_1 (
150148 const int64_t * sorted_indices, const int64_t * indices, const scalar_t * grad_output, scalar_t * grad_weight,
@@ -676,6 +674,7 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
676674 auto vals_shape = valsShape (src.sizes (), dims_before, dims_indexed, linearIndex.sizes ());
677675 int64_t num_indices = linearIndex.numel ();
678676 expandedValue = expandedValue.expand (vals_shape).contiguous ();
679678 
680679 if (num_indices > 0 && sliceSize > 0 ) {
681680 const bool permuted = !src.is_contiguous ();
@@ -772,7 +771,7 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
772771 C10_CUDA_KERNEL_LAUNCH_CHECK ();
773772 }),
774773 AT_EXPAND (AT_ALL_TYPES_AND_COMPLEX),
775- // AT_EXPAND(AT_FLOAT8_TYPES),
774+ // AT_EXPAND(AT_FLOAT8_TYPES),
776775 // TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
777776 // should not be supported here, then reenable AT_FLOAT8_DTYPES
778777 kFloat8_e4m3fn ,
@@ -784,6 +783,67 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
784783 kBool ,
785784 kBFloat16 );
786785 } else {
786+ #ifdef USE_ROCM
787+ if (num_indices < 200000 ) {
788+ AT_DISPATCH_V2 (
789+ expandedValue.scalar_type (),
790+ " indexing_backward" ,
791+ AT_WRAP ([&] {
792+ indexing_backward_kernel<scalar_t , UNROLL><<<grid, block, 0 , stream>>> (
793+ sorted_indices.const_data_ptr <int64_t >(),
794+ orig_indices.const_data_ptr <int64_t >(),
795+ expandedValue.const_data_ptr <scalar_t >(),
796+ src_.mutable_data_ptr <scalar_t >(),
797+ num_indices,
798+ sliceSize,
799+ strideBefore,
800+ nElemBefore,
801+ accumulate);
802+ C10_CUDA_KERNEL_LAUNCH_CHECK ();
803+ }),
804+ AT_EXPAND (AT_ALL_TYPES_AND_COMPLEX),
805+ // AT_EXPAND(AT_FLOAT8_TYPES),
806+ // TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
807+ // should not be supported here, then reenable AT_FLOAT8_DTYPES
808+ kFloat8_e4m3fn ,
809+ kFloat8_e5m2 ,
810+ kFloat8_e4m3fnuz ,
811+ kFloat8_e5m2fnuz ,
812+ kComplexHalf ,
813+ kHalf ,
814+ kBool ,
815+ kBFloat16 );
816+ } else {
817+ AT_DISPATCH_V2 (
818+ expandedValue.scalar_type (),
819+ " indexing_backward_many_indices" ,
820+ AT_WRAP ([&] {
821+ indexing_backward_kernel_many_indices<scalar_t , UNROLL><<<new_grid, block, smem_dups_size, stream>>> (
822+ sorted_indices.const_data_ptr <int64_t >(),
823+ orig_indices.const_data_ptr <int64_t >(),
824+ expandedValue.const_data_ptr <scalar_t >(),
825+ src_.mutable_data_ptr <scalar_t >(),
826+ num_indices,
827+ sliceSize,
828+ strideBefore,
829+ nElemBefore,
830+ accumulate);
831+ C10_CUDA_KERNEL_LAUNCH_CHECK ();
832+ }),
833+ AT_EXPAND (AT_ALL_TYPES_AND_COMPLEX),
834+ // AT_EXPAND(AT_FLOAT8_TYPES),
835+ // TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
836+ // should not be supported here, then reenable AT_FLOAT8_DTYPES
837+ kFloat8_e4m3fn ,
838+ kFloat8_e5m2 ,
839+ kFloat8_e4m3fnuz ,
840+ kFloat8_e5m2fnuz ,
841+ kComplexHalf ,
842+ kHalf ,
843+ kBool ,
844+ kBFloat16 );
845+ }
846+ #else
787847 AT_DISPATCH_V2 (
788848 expandedValue.scalar_type (),
789849 " indexing_backward" ,
@@ -812,6 +872,7 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
812872 kHalf ,
813873 kBool ,
814874 kBFloat16 );
875+ #endif
815876 }
816877 }
817878
0 commit comments