@@ -567,7 +567,7 @@ static std::vector<int64_t> computeLinearStride(const Tensor & tensor) {
   return stride;
 }

-static std::tuple<Tensor, int64_t, int64_t, int64_t>
+static std::tuple<Tensor, int64_t, int64_t, int64_t, int64_t, int64_t>
 computeLinearIndex(const Tensor & src, TensorList indices, bool check_range) {
   auto strides = computeLinearStride(src);
   const auto& device = src.options().device();
@@ -578,8 +578,10 @@ computeLinearIndex(const Tensor & src, TensorList indices, bool check_range) {
   // are not being index.
   Tensor linearIndex;
   int64_t nElemBefore = 1, nElemAfter = 1, strideBefore = 0;
+  int64_t dims_before = 0, dims_indexed = 0;
   for (const auto i: c10::irange(src.dim())) {
     if (indices[i].defined()) {
+      dims_indexed++;
       // Cast index to the longType matching src's device
       // This allows us to support ie indexing a cuda tensor with a cpu tensor
       Tensor index = (wrapIndexOnce(indices[i], i, src.size(i), check_range) * strides[i]).to(device);
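In the loop above, `dims_before` counts the leading dimensions of `src` that have no defined index tensor, and `dims_indexed` counts the dimensions that are indexed (a contiguous block at this point, thanks to the transpose performed in `makeLinearIndex` below). A minimal standalone sketch of that counting, with a boolean mask standing in for `indices[i].defined()`; the helper and driver names are illustrative and not part of the patch:

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Count leading unindexed dims and indexed dims, mirroring the loop in
    // computeLinearIndex. defined[i] stands in for indices[i].defined().
    static void countIndexedDims(const std::vector<bool>& defined,
                                 int64_t& dims_before, int64_t& dims_indexed) {
      dims_before = 0;
      dims_indexed = 0;
      bool seen_indexed = false;
      for (bool d : defined) {
        if (d) {
          dims_indexed++;
          seen_indexed = true;
        } else if (!seen_indexed) {
          dims_before++;   // leading dimension that is not indexed
        }                  // trailing unindexed dims only affect nElemAfter
      }
    }

    int main() {
      // e.g. src[:, i, j, :] -> one leading unindexed dim, two indexed dims
      int64_t before = 0, indexed = 0;
      countIndexedDims({false, true, true, false}, before, indexed);
      assert(before == 1 && indexed == 2);
    }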
@@ -594,15 +596,17 @@ computeLinearIndex(const Tensor & src, TensorList indices, bool check_range) {
     } else if (linearIndex.defined()) {
       nElemAfter *= src.size(i);
     } else {
+      dims_before++;
       nElemBefore *= src.size(i);
     }
   }

-  return std::make_tuple(std::move(linearIndex), nElemBefore, strideBefore, nElemAfter);
+  return std::make_tuple(std::move(linearIndex), nElemBefore, strideBefore, nElemAfter, dims_before, dims_indexed);
 }


-static std::tuple<Tensor, Tensor, int64_t, int64_t, int64_t, std::vector<int64_t>> makeLinearIndex(Tensor self, IOptTensorListRef orig, bool check_range) {
+static std::tuple<Tensor, Tensor, int64_t, int64_t, int64_t, std::vector<int64_t>, int64_t, int64_t>
+makeLinearIndex(Tensor self, IOptTensorListRef orig, bool check_range) {
   checkIndexTensorTypes(orig, /*allow_int*/true);
   // first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more LongTensors
   auto indices = expandTensors(self, orig);
@@ -623,13 +627,11 @@ static std::tuple<Tensor, Tensor, int64_t, int64_t, int64_t, std::vector<int64_t
   if (!hasContiguousSubspace(indices)) {
     std::tie(self, indices, inversePerm) = transposeToFrontAndInvPerm(self, indices);
   }
-  auto [linearIndex, nElemBefore, strideBefore, nElemAfter] = computeLinearIndex(self, indices, check_range);
-  return std::make_tuple(linearIndex, self, nElemBefore, strideBefore, nElemAfter, inversePerm);
+  auto [linearIndex, nElemBefore, strideBefore, nElemAfter, dims_before, dims_indexed] =
+      computeLinearIndex(self, indices, check_range);
+  return std::make_tuple(linearIndex, self, nElemBefore, strideBefore, nElemAfter, inversePerm,
+      dims_before, dims_indexed);
 }
-
-
-void index_put_with_sort_kernel_thrust_helper(Tensor &linearIndex, Tensor &orig_indices, Tensor &sorted_indices, int64_t num_indices);
-
 namespace {

 int64_t largestIndex(const Tensor &self) {
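When the indexed dimensions are not a contiguous block, `transposeToFrontAndInvPerm` permutes them to the front and records the inverse permutation so the caller can restore the original layout afterwards. A rough sketch of that bookkeeping on plain vectors (illustrative only, not the ATen helper):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Build a permutation that moves the indexed dimensions to the front
    // (preserving their order) plus its inverse, similar in spirit to
    // transposeToFrontAndInvPerm. Not the ATen implementation.
    static void frontPermutation(const std::vector<bool>& indexed,
                                 std::vector<int64_t>& perm,
                                 std::vector<int64_t>& inversePerm) {
      const int64_t ndim = static_cast<int64_t>(indexed.size());
      perm.clear();
      for (int64_t i = 0; i < ndim; ++i) if (indexed[i])  perm.push_back(i);
      for (int64_t i = 0; i < ndim; ++i) if (!indexed[i]) perm.push_back(i);
      inversePerm.assign(ndim, 0);
      for (int64_t i = 0; i < ndim; ++i) inversePerm[perm[i]] = i;
    }

    int main() {
      // Indexed dims 1 and 3 are not contiguous: they move to the front.
      std::vector<int64_t> perm, inv;
      frontPermutation({false, true, false, true}, perm, inv);
      assert((perm == std::vector<int64_t>{1, 3, 0, 2}));
      assert((inv  == std::vector<int64_t>{2, 0, 3, 1}));
    }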
@@ -640,6 +642,20 @@ int64_t largestIndex(const Tensor &self) {
   return result;
 }

+DimVector valsShape(IntArrayRef self_sizes,
+                    int64_t dims_before,
+                    int64_t dims_indexed,
+                    IntArrayRef replacement_shape) {
+  auto shape = DimVector(self_sizes);
+  int64_t end = dims_before + dims_indexed;
+  shape.erase(shape.begin() + dims_before, shape.begin() + end);
+  shape.insert(
+      shape.begin() + dims_before,
+      replacement_shape.begin(),
+      replacement_shape.end());
+  return shape;
+}
+
 void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Tensor>>& indices, const Tensor & value, bool accumulate, bool unsafe) {
   TORCH_CHECK(!indices.empty() || is_expandable_to(value.sizes(), self.sizes()), "shape mismatch: value tensor of shape ", value.sizes(),
     " cannot be broadcast to indexing result of shape ", self.sizes());
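A quick sanity check on the new `valsShape` helper: the indexed dimensions of `self` are cut out and replaced by the shape of the broadcast index (which is what `linearIndex.sizes()` supplies at the call sites below). The sketch uses `std::vector<int64_t>` in place of `DimVector`/`IntArrayRef` so it compiles without ATen; the sizes are made up for illustration:

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Same erase/insert logic as valsShape, expressed with std::vector so it
    // can be compiled and tested without ATen types.
    static std::vector<int64_t> valsShapeRef(std::vector<int64_t> shape,
                                             int64_t dims_before,
                                             int64_t dims_indexed,
                                             const std::vector<int64_t>& replacement_shape) {
      const int64_t end = dims_before + dims_indexed;
      shape.erase(shape.begin() + dims_before, shape.begin() + end);
      shape.insert(shape.begin() + dims_before,
                   replacement_shape.begin(), replacement_shape.end());
      return shape;
    }

    int main() {
      // self of shape {2, 3, 4, 5}, dims 1 and 2 indexed by an index of shape {6}:
      // the values being written must broadcast to {2, 6, 5}.
      auto shape = valsShapeRef({2, 3, 4, 5}, /*dims_before=*/1, /*dims_indexed=*/2, {6});
      assert((shape == std::vector<int64_t>{2, 6, 5}));
    }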
@@ -649,27 +665,13 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
   bool self_contiguous = self.is_contiguous();
   auto self_ = self_contiguous ? self : self.contiguous();
   Tensor linearIndex, src, expandedValue = value;
-  int64_t nElemBefore, strideBefore, sliceSize;
+  int64_t nElemBefore, strideBefore, sliceSize, dims_before, dims_indexed;
   std::vector<int64_t> inversePerm;
-  std::tie(linearIndex, src, nElemBefore, strideBefore, sliceSize, inversePerm) = makeLinearIndex(self_, indices, !unsafe);
+  std::tie(linearIndex, src, nElemBefore, strideBefore, sliceSize, inversePerm,
+      dims_before, dims_indexed) = makeLinearIndex(self_, indices, !unsafe);
+  auto vals_shape = valsShape(src.sizes(), dims_before, dims_indexed, linearIndex.sizes());
   int64_t num_indices = linearIndex.numel();
-
-  if (expandedValue.numel() < num_indices * nElemBefore * sliceSize) {
-    auto expanded_size = at::DimVector(expandedValue.sizes());
-    auto size1 = expandedValue.sizes();
-    auto size2 = linearIndex.sizes();
-    if (are_expandable(size1, size2)) {
-      expanded_size = infer_size_dimvector(size1, size2);
-    }
-    if (nElemBefore > 1) {
-      expanded_size.insert(expanded_size.begin(), nElemBefore);
-    }
-    if (sliceSize > 1) {
-      expanded_size.insert(expanded_size.end(), sliceSize);
-    }
-    expandedValue = expandedValue.expand(expanded_size);
-  }
-  expandedValue = expandedValue.contiguous();
+  expandedValue = expandedValue.expand(vals_shape).contiguous();

   if (num_indices > 0 && sliceSize > 0) {
     const bool permuted = !src.is_contiguous();
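The single `expand(vals_shape)` replaces the previous ad-hoc size assembly and lines up with the `TORCH_INTERNAL_ASSERT` further down, which requires `linearIndex.numel() * sliceSize * nElemBefore == expandedValue.numel()`. A worked example of that bookkeeping with made-up sizes and no ATen types:

    #include <cassert>
    #include <cstdint>
    #include <functional>
    #include <numeric>
    #include <vector>

    int main() {
      // self of shape {2, 3, 4, 5}, dims 1 and 2 indexed with an index of shape {6}.
      const int64_t nElemBefore = 2;   // product of dims before the indexed block
      const int64_t sliceSize   = 5;   // product of dims after the indexed block
      const int64_t num_indices = 6;   // linearIndex.numel()
      const std::vector<int64_t> vals_shape{2, 6, 5};  // valsShape(...) result

      // expandedValue.numel() after expand(vals_shape):
      const int64_t expanded_numel = std::accumulate(
          vals_shape.begin(), vals_shape.end(), int64_t{1}, std::multiplies<int64_t>());

      // The invariant checked by TORCH_INTERNAL_ASSERT in the kernel:
      assert(expanded_numel == num_indices * sliceSize * nElemBefore);
    }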
@@ -681,15 +683,6 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten

     linearIndex.divide_(sliceSize, "trunc");

-    // cub on CUDA <= 11.2 have a bug that for small sizes
-    // cub's sort can be much slower than thrust's merge sort
-    // this bug is fixed in CUDA 11.3
-#if (defined(CUDA_VERSION) && CUDA_VERSION < 11030) && !defined(USE_ROCM)
-    if (num_indices < 50000) {
-      index_put_with_sort_kernel_thrust_helper(linearIndex, orig_indices, sorted_indices, num_indices);
-    } else
-#endif
-    {
     // Sort the inputs into sorted with the corresponding indices
     auto range = at::arange(num_indices, linearIndex.options());
     // linearIndex can not be negative, and we take advantage of this
@@ -699,7 +692,7 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
         linearIndex.const_data_ptr<int64_t>(), sorted_indices.mutable_data_ptr<int64_t>(),
         range.const_data_ptr<int64_t>(), orig_indices.mutable_data_ptr<int64_t>(),
         num_indices, false, 0, nbits);
-    }
+

     TORCH_INTERNAL_ASSERT(
       linearIndex.numel()*sliceSize*nElemBefore == expandedValue.numel(),
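With the thrust fallback for CUDA < 11.3 removed, the cub radix sort path is always taken. Its `nbits` argument caps how many key bits the sort has to process; it is derived elsewhere from the largest possible linear index (that computation is outside this hunk). A hedged sketch of one way to bound the bit count, not necessarily the exact expression used in the kernel:

    #include <cassert>
    #include <cstdint>

    // Number of significant bits needed to represent keys in [0, max_key].
    // Illustrative only; the kernel computes nbits elsewhere from the largest
    // possible key.
    static int numSignificantBits(uint64_t max_key) {
      int bits = 0;
      while (max_key > 0) {
        ++bits;
        max_key >>= 1;
      }
      return bits == 0 ? 1 : bits;  // sort at least one bit
    }

    int main() {
      assert(numSignificantBits(0) == 1);
      assert(numSignificantBits(1) == 1);
      assert(numSignificantBits(255) == 8);
      assert(numSignificantBits(256) == 9);
    }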
@@ -838,24 +831,13 @@ void index_put_with_sort_quantized(Tensor & self, const c10::List<std::optional<
   bool self_contiguous = self.is_contiguous();
   auto self_ = self_contiguous ? self : self.contiguous();
   Tensor linearIndex, src, expandedValue = value;
-  int64_t nElemBefore, strideBefore, sliceSize;
+  int64_t nElemBefore, strideBefore, sliceSize, dims_before, dims_indexed;
   std::vector<int64_t> inversePerm;
-  std::tie(linearIndex, src, nElemBefore, strideBefore, sliceSize, inversePerm) = makeLinearIndex(self_, indices, !unsafe);
+  std::tie(linearIndex, src, nElemBefore, strideBefore, sliceSize, inversePerm,
+      dims_before, dims_indexed) = makeLinearIndex(self_, indices, !unsafe);
+  auto vals_shape = valsShape(src.sizes(), dims_before, dims_indexed, linearIndex.sizes());
   int64_t num_indices = linearIndex.numel();
-
-  if (expandedValue.numel() < num_indices * nElemBefore * sliceSize) {
-    auto expanded_size = at::DimVector(expandedValue.sizes());
-    auto size1 = expandedValue.sizes();
-    auto size2 = linearIndex.sizes();
-    if (are_expandable(size1, size2)) {
-      expanded_size = infer_size_dimvector(size1, size2);
-    }
-    if (nElemBefore > 1) {
-      expanded_size.insert(expanded_size.begin(), nElemBefore);
-    }
-    expandedValue = expandedValue.expand(expanded_size);
-  }
-  expandedValue = expandedValue.contiguous();
+  expandedValue = expandedValue.expand(vals_shape).contiguous();

   if (num_indices > 0 && sliceSize > 0) {
     const bool permuted = !src.is_contiguous();
@@ -867,15 +849,6 @@ void index_put_with_sort_quantized(Tensor & self, const c10::List<std::optional<
867849
868850 linearIndex.divide_ (sliceSize, " trunc" );
869851
870- // cub on CUDA <= 11.2 have a bug that for small sizes
871- // cub's sort can be much slower than thrust's merge sort
872- // this bug is fixed in CUDA 11.3
873- #if (defined(CUDA_VERSION) && CUDA_VERSION < 11030) && !defined(USE_ROCM)
874- if (num_indices < 50000 ) {
875- index_put_with_sort_kernel_thrust_helper (linearIndex, orig_indices, sorted_indices, num_indices);
876- } else
877- #endif
878- {
879852 // Sort the inputs into sorted with the corresponding indices
880853 auto range = at::arange (num_indices, linearIndex.options ());
881854 // linearIndex can not be negative, and we take advantage of this
@@ -885,7 +858,7 @@ void index_put_with_sort_quantized(Tensor & self, const c10::List<std::optional<
         linearIndex.const_data_ptr<int64_t>(), sorted_indices.mutable_data_ptr<int64_t>(),
         range.const_data_ptr<int64_t>(), orig_indices.mutable_data_ptr<int64_t>(),
         num_indices, false, 0, nbits);
-    }
+

     TORCH_INTERNAL_ASSERT(
       linearIndex.numel()*sliceSize*nElemBefore == expandedValue.numel(),