Commit cf04d85

Fix non-stride-one backwards indexing performance
1 parent 0b82d9a commit cf04d85
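
The path touched here is the sort-based backward for advanced indexing (index_put_with_sort_kernel) when accumulate is true and each index addresses a slice larger than one element, i.e. the non-stride-one case. As a rough illustration only (not part of this commit), the libtorch sketch below is intended to exercise that path; the tensor sizes are illustrative assumptions, and with 300000 indices a ROCm build would cross the 200000-index threshold introduced here and take the new many-indices branch.

// Minimal sketch, not from this commit: drives the accumulate=true, sliceSize > 1
// ("non-stride-one") path served by index_put_with_sort_kernel. Sizes are
// illustrative assumptions; on a ROCm build, 300000 indices exceed the 200000
// threshold added here, so indexing_backward_kernel_many_indices would be used.
#include <torch/torch.h>

int main() {
  auto self   = torch::zeros({100000, 256}, torch::device(torch::kCUDA));
  // Duplicate indices are expected, so the backward kernel must accumulate
  // several gradient rows into the same destination row.
  auto index  = torch::randint(0, 100000, {300000},
                               torch::dtype(torch::kLong).device(torch::kCUDA));
  auto values = torch::randn({300000, 256}, torch::device(torch::kCUDA));

  // Each index selects a whole 256-element row (sliceSize = 256), unlike the
  // stride-1 kernel that handles sliceSize == 1.
  c10::List<std::optional<torch::Tensor>> indices;
  indices.push_back(index);
  self.index_put_(indices, values, /*accumulate=*/true);

  torch::cuda::synchronize();
  return 0;
}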

File tree: 1 file changed (+67 −6 lines)

aten/src/ATen/native/cuda/Indexing.cu

Lines changed: 67 additions & 6 deletions
@@ -56,9 +56,10 @@ constexpr uint64_t getDefaultMaxThreadsPerBlock() {
 #endif
 }
 
-#if 0
+#ifdef USE_ROCM
+#define SKIP_SORTED_INDICES 32
 template <typename scalar_t, int SZ>
-__global__ void indexing_backward_kernel(
+__global__ void indexing_backward_kernel_many_indices(
   const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
   int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim, bool accumulate) {
   using opmath_t = at::opmath_type<scalar_t>;
@@ -141,10 +142,7 @@ __global__ void indexing_backward_kernel(
     }
   }
 }
-#endif
 
-#ifdef USE_ROCM
-#define SKIP_SORTED_INDICES 32
 template <typename scalar_t>
 __global__ void indexing_backward_kernel_stride_1(
   const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
@@ -676,6 +674,7 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
   auto vals_shape = valsShape(src.sizes(), dims_before, dims_indexed, linearIndex.sizes());
   int64_t num_indices = linearIndex.numel();
   expandedValue = expandedValue.expand(vals_shape).contiguous();
+  printf("num_indices = %d\n", num_indices);
 
   if (num_indices > 0 && sliceSize > 0) {
     const bool permuted = !src.is_contiguous();
@@ -772,7 +771,7 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
           C10_CUDA_KERNEL_LAUNCH_CHECK();
         }),
         AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
-        // AT_EXPAND(AT_FLOAT8_TYPES),
+        // AT_EXPAND(AT_FLOAT8_TYPES),3
         // TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
         // should not be supported here, then reenable AT_FLOAT8_DTYPES
         kFloat8_e4m3fn,
@@ -784,6 +783,67 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
         kBool,
         kBFloat16);
     } else {
+#ifdef USE_ROCM
+      if (num_indices < 200000) {
+        AT_DISPATCH_V2(
+          expandedValue.scalar_type(),
+          "indexing_backward",
+          AT_WRAP([&] {
+            indexing_backward_kernel<scalar_t, UNROLL><<<grid, block, 0, stream>>>(
+              sorted_indices.const_data_ptr<int64_t>(),
+              orig_indices.const_data_ptr<int64_t>(),
+              expandedValue.const_data_ptr<scalar_t>(),
+              src_.mutable_data_ptr<scalar_t>(),
+              num_indices,
+              sliceSize,
+              strideBefore,
+              nElemBefore,
+              accumulate);
+            C10_CUDA_KERNEL_LAUNCH_CHECK();
+          }),
+          AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
+          // AT_EXPAND(AT_FLOAT8_TYPES),
+          // TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
+          // should not be supported here, then reenable AT_FLOAT8_DTYPES
+          kFloat8_e4m3fn,
+          kFloat8_e5m2,
+          kFloat8_e4m3fnuz,
+          kFloat8_e5m2fnuz,
+          kComplexHalf,
+          kHalf,
+          kBool,
+          kBFloat16);
+      } else {
+        AT_DISPATCH_V2(
+          expandedValue.scalar_type(),
+          "indexing_backward_many_indices",
+          AT_WRAP([&] {
+            indexing_backward_kernel_many_indices<scalar_t, UNROLL><<<new_grid, block, smem_dups_size, stream>>>(
+              sorted_indices.const_data_ptr<int64_t>(),
+              orig_indices.const_data_ptr<int64_t>(),
+              expandedValue.const_data_ptr<scalar_t>(),
+              src_.mutable_data_ptr<scalar_t>(),
+              num_indices,
+              sliceSize,
+              strideBefore,
+              nElemBefore,
+              accumulate);
+            C10_CUDA_KERNEL_LAUNCH_CHECK();
+          }),
+          AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
+          // AT_EXPAND(AT_FLOAT8_TYPES),
+          // TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
+          // should not be supported here, then reenable AT_FLOAT8_DTYPES
+          kFloat8_e4m3fn,
+          kFloat8_e5m2,
+          kFloat8_e4m3fnuz,
+          kFloat8_e5m2fnuz,
+          kComplexHalf,
+          kHalf,
+          kBool,
+          kBFloat16);
+      }
+#else
       AT_DISPATCH_V2(
         expandedValue.scalar_type(),
         "indexing_backward",
@@ -812,6 +872,7 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
         kHalf,
         kBool,
         kBFloat16);
+#endif
     }
   }
 