
Commit b04852e

ngimel authored and pytorchmergebot committed
Fix deterministic indexing with broadcast (pytorch#154296)
Fixes pytorch#79987, now for real. Also removed the thrust sort path that was needed for CUDA <= 11.2, because we no longer support it.

Pull Request resolved: pytorch#154296
Approved by: https://github.com/soumith
1 parent c310006 commit b04852e
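As context for the diffs below: the case this commit fixes is a deterministic index_put on CUDA whose value tensor must be broadcast against the indexing result. A minimal repro sketch (my own, not part of the commit; it assumes a CUDA build and uses torch.use_deterministic_algorithms where the new test uses the test-suite helper DeterministicGuard):

import torch

# Hedged repro sketch of the scenario fixed here (assumes a CUDA device).
# With deterministic algorithms enabled, `t[..., idx] = v` goes through the
# sort-based index_put_ kernel, and v must be broadcast to the indexed shape.
if torch.cuda.is_available():
    torch.use_deterministic_algorithms(True)
    t = torch.zeros(5, 4, device="cuda")
    idx = torch.tensor([0, 2, 3], device="cuda")
    v = torch.randn(5, 1, device="cuda")  # broadcasts against t[..., idx] of shape (5, 3)
    t[..., idx] = v

    ref = torch.zeros(5, 4)
    ref[..., idx.cpu()] = v.cpu()
    assert torch.equal(t.cpu(), ref)
    torch.use_deterministic_algorithms(False)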

File tree

4 files changed (+72, -100 lines)


aten/src/ATen/native/TensorAdvancedIndexing.cpp

Lines changed: 2 additions & 1 deletion
@@ -994,7 +994,8 @@ Tensor& _index_put_impl_(
   }
   if ((self.device().type() == DeviceType::CUDA ||
        self.device().type() == DeviceType::XPU) &&
-      (accumulate || globalContext().deterministicAlgorithms())) {
+      (accumulate ||
+       (globalContext().deterministicAlgorithms() && value_.numel() > 1))) {
     TORCH_CHECK(
         value_.device() == self.device(),
         "expected device ",

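A hedged reading of the new `value_.numel() > 1` guard: a single-element value writes the same number to every target location, so the result cannot depend on write order and the sort-based deterministic path is unnecessary for it. An illustrative sketch of the case that now stays on the default kernel (tensor names and shapes are made up, not taken from the patch):

import torch

# Sketch: a scalar (numel() == 1) value under deterministic mode.
# Such an index_put_ no longer needs the sort-based kernel, since every
# duplicate index receives the identical value regardless of write order.
if torch.cuda.is_available():
    torch.use_deterministic_algorithms(True)
    t = torch.zeros(8, device="cuda")
    idx = torch.tensor([1, 1, 3], device="cuda")  # duplicate indices
    t.index_put_((idx,), torch.tensor(7.0, device="cuda"))
    assert torch.equal(t.cpu(), torch.tensor([0., 7., 0., 7., 0., 0., 0., 0.]))
    torch.use_deterministic_algorithms(False)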
aten/src/ATen/native/cuda/Indexing.cu

Lines changed: 37 additions & 64 deletions
@@ -567,7 +567,7 @@ static std::vector<int64_t> computeLinearStride(const Tensor & tensor) {
   return stride;
 }
 
-static std::tuple<Tensor, int64_t, int64_t, int64_t>
+static std::tuple<Tensor, int64_t, int64_t, int64_t, int64_t, int64_t>
 computeLinearIndex(const Tensor & src, TensorList indices, bool check_range) {
   auto strides = computeLinearStride(src);
   const auto& device = src.options().device();
@@ -578,8 +578,10 @@ computeLinearIndex(const Tensor & src, TensorList indices, bool check_range) {
   // are not being index.
   Tensor linearIndex;
   int64_t nElemBefore = 1, nElemAfter = 1, strideBefore =0;
+  int64_t dims_before = 0, dims_indexed = 0;
   for (const auto i: c10::irange(src.dim())) {
     if (indices[i].defined()) {
+      dims_indexed++;
       // Cast index to the longType matching src's device
       // This allows us to support ie indexing a cuda tensor with a cpu tensor
       Tensor index = (wrapIndexOnce(indices[i], i, src.size(i), check_range) * strides[i]).to(device);
@@ -594,15 +596,17 @@ computeLinearIndex(const Tensor & src, TensorList indices, bool check_range) {
     } else if (linearIndex.defined()) {
       nElemAfter *= src.size(i);
     } else {
+      dims_before++;
       nElemBefore *= src.size(i);
     }
   }
 
-  return std::make_tuple(std::move(linearIndex), nElemBefore, strideBefore, nElemAfter);
+  return std::make_tuple(std::move(linearIndex), nElemBefore, strideBefore, nElemAfter, dims_before, dims_indexed);
 }
 
 
-static std::tuple<Tensor, Tensor, int64_t, int64_t, int64_t, std::vector<int64_t>> makeLinearIndex(Tensor self, IOptTensorListRef orig, bool check_range) {
+static std::tuple<Tensor, Tensor, int64_t, int64_t, int64_t, std::vector<int64_t>, int64_t, int64_t>
+makeLinearIndex(Tensor self, IOptTensorListRef orig, bool check_range) {
   checkIndexTensorTypes(orig, /*allow_int*/true);
   // first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more LongTensors
   auto indices = expandTensors(self, orig);
@@ -623,13 +627,11 @@ static std::tuple<Tensor, Tensor, int64_t, int64_t, int64_t, std::vector<int64_t
   if (!hasContiguousSubspace(indices)) {
     std::tie(self, indices, inversePerm) = transposeToFrontAndInvPerm(self, indices);
   }
-  auto [linearIndex, nElemBefore, strideBefore, nElemAfter] = computeLinearIndex(self, indices, check_range);
-  return std::make_tuple(linearIndex, self, nElemBefore, strideBefore, nElemAfter, inversePerm);
+  auto [linearIndex, nElemBefore, strideBefore, nElemAfter, dims_before, dims_indexed] =
+      computeLinearIndex(self, indices, check_range);
+  return std::make_tuple(linearIndex, self, nElemBefore, strideBefore, nElemAfter, inversePerm,
+                         dims_before, dims_indexed);
 }
-
-
-void index_put_with_sort_kernel_thrust_helper(Tensor &linearIndex, Tensor &orig_indices, Tensor &sorted_indices, int64_t num_indices);
-
 namespace {
 
 int64_t largestIndex(const Tensor &self) {
@@ -640,6 +642,20 @@ int64_t largestIndex(const Tensor &self) {
   return result;
 }
 
+DimVector valsShape(IntArrayRef self_sizes,
+                    int64_t dims_before,
+                    int64_t dims_indexed,
+                    IntArrayRef replacement_shape) {
+  auto shape = DimVector(self_sizes);
+  int64_t end = dims_before + dims_indexed;
+  shape.erase(shape.begin() + dims_before, shape.begin() + end);
+  shape.insert(
+      shape.begin() + dims_before,
+      replacement_shape.begin(),
+      replacement_shape.end());
+  return shape;
+}
+
 void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Tensor>>& indices, const Tensor & value, bool accumulate, bool unsafe) {
   TORCH_CHECK(!indices.empty() || is_expandable_to(value.sizes(), self.sizes()), "shape mismatch: value tensor of shape ", value.sizes(),
               " cannot be broadcast to indexing result of shape ", self.sizes());
@@ -649,27 +665,13 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
   bool self_contiguous = self.is_contiguous();
   auto self_ = self_contiguous ? self : self.contiguous();
   Tensor linearIndex, src, expandedValue = value;
-  int64_t nElemBefore, strideBefore, sliceSize;
+  int64_t nElemBefore, strideBefore, sliceSize, dims_before, dims_indexed;
   std::vector<int64_t> inversePerm;
-  std::tie(linearIndex, src, nElemBefore, strideBefore, sliceSize, inversePerm) = makeLinearIndex(self_, indices, !unsafe);
+  std::tie(linearIndex, src, nElemBefore, strideBefore, sliceSize, inversePerm,
+           dims_before, dims_indexed) = makeLinearIndex(self_, indices, !unsafe);
+  auto vals_shape = valsShape(src.sizes(), dims_before, dims_indexed, linearIndex.sizes());
   int64_t num_indices = linearIndex.numel();
-
-  if (expandedValue.numel() < num_indices * nElemBefore * sliceSize) {
-    auto expanded_size = at::DimVector(expandedValue.sizes());
-    auto size1 = expandedValue.sizes();
-    auto size2 = linearIndex.sizes();
-    if (are_expandable(size1, size2)) {
-      expanded_size = infer_size_dimvector(size1, size2);
-    }
-    if (nElemBefore > 1) {
-      expanded_size.insert(expanded_size.begin(), nElemBefore);
-    }
-    if (sliceSize > 1) {
-      expanded_size.insert(expanded_size.end(), sliceSize);
-    }
-    expandedValue = expandedValue.expand(expanded_size);
-  }
-  expandedValue = expandedValue.contiguous();
+  expandedValue = expandedValue.expand(vals_shape).contiguous();
 
   if (num_indices > 0 && sliceSize > 0) {
     const bool permuted = !src.is_contiguous();
@@ -681,15 +683,6 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
 
       linearIndex.divide_(sliceSize, "trunc");
 
-      // cub on CUDA <= 11.2 have a bug that for small sizes
-      // cub's sort can be much slower than thrust's merge sort
-      // this bug is fixed in CUDA 11.3
-#if (defined(CUDA_VERSION) && CUDA_VERSION < 11030) && !defined(USE_ROCM)
-      if (num_indices < 50000) {
-        index_put_with_sort_kernel_thrust_helper(linearIndex, orig_indices, sorted_indices, num_indices);
-      } else
-#endif
-      {
       // Sort the inputs into sorted with the corresponding indices
       auto range = at::arange(num_indices, linearIndex.options());
       // linearIndex can not be negative, and we take advantage of this
@@ -699,7 +692,7 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
             linearIndex.const_data_ptr<int64_t>(), sorted_indices.mutable_data_ptr<int64_t>(),
             range.const_data_ptr<int64_t>(), orig_indices.mutable_data_ptr<int64_t>(),
             num_indices, false, 0, nbits);
-      }
+
 
       TORCH_INTERNAL_ASSERT(
           linearIndex.numel()*sliceSize*nElemBefore == expandedValue.numel(),
@@ -838,24 +831,13 @@ void index_put_with_sort_quantized(Tensor & self, const c10::List<std::optional<
   bool self_contiguous = self.is_contiguous();
   auto self_ = self_contiguous ? self : self.contiguous();
   Tensor linearIndex, src, expandedValue = value;
-  int64_t nElemBefore, strideBefore, sliceSize;
+  int64_t nElemBefore, strideBefore, sliceSize, dims_before, dims_indexed;
   std::vector<int64_t> inversePerm;
-  std::tie(linearIndex, src, nElemBefore, strideBefore, sliceSize, inversePerm) = makeLinearIndex(self_, indices, !unsafe);
+  std::tie(linearIndex, src, nElemBefore, strideBefore, sliceSize, inversePerm,
+           dims_before, dims_indexed) = makeLinearIndex(self_, indices, !unsafe);
+  auto vals_shape = valsShape(src.sizes(), dims_before, dims_indexed, linearIndex.sizes());
   int64_t num_indices = linearIndex.numel();
-
-  if (expandedValue.numel() < num_indices * nElemBefore * sliceSize) {
-    auto expanded_size = at::DimVector(expandedValue.sizes());
-    auto size1 = expandedValue.sizes();
-    auto size2 = linearIndex.sizes();
-    if (are_expandable(size1, size2)) {
-      expanded_size = infer_size_dimvector(size1, size2);
-    }
-    if (nElemBefore > 1) {
-      expanded_size.insert(expanded_size.begin(), nElemBefore);
-    }
-    expandedValue = expandedValue.expand(expanded_size);
-  }
-  expandedValue = expandedValue.contiguous();
+  expandedValue = expandedValue.expand(vals_shape).contiguous();
 
   if (num_indices > 0 && sliceSize > 0) {
     const bool permuted = !src.is_contiguous();
@@ -867,15 +849,6 @@ void index_put_with_sort_quantized(Tensor & self, const c10::List<std::optional<
 
       linearIndex.divide_(sliceSize, "trunc");
 
-      // cub on CUDA <= 11.2 have a bug that for small sizes
-      // cub's sort can be much slower than thrust's merge sort
-      // this bug is fixed in CUDA 11.3
-#if (defined(CUDA_VERSION) && CUDA_VERSION < 11030) && !defined(USE_ROCM)
-      if (num_indices < 50000) {
-        index_put_with_sort_kernel_thrust_helper(linearIndex, orig_indices, sorted_indices, num_indices);
-      } else
-#endif
-      {
       // Sort the inputs into sorted with the corresponding indices
       auto range = at::arange(num_indices, linearIndex.options());
       // linearIndex can not be negative, and we take advantage of this
@@ -885,7 +858,7 @@ void index_put_with_sort_quantized(Tensor & self, const c10::List<std::optional<
             linearIndex.const_data_ptr<int64_t>(), sorted_indices.mutable_data_ptr<int64_t>(),
             range.const_data_ptr<int64_t>(), orig_indices.mutable_data_ptr<int64_t>(),
             num_indices, false, 0, nbits);
-      }
+
 
       TORCH_INTERNAL_ASSERT(
           linearIndex.numel()*sliceSize*nElemBefore == expandedValue.numel(),
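The new valsShape helper replaces the ad-hoc size inference deleted above: it takes self's sizes and substitutes the shape of the linearized index for the dims_indexed dimensions consumed by the advanced indices; the value tensor is then expanded to that shape and made contiguous. A rough Python equivalent, given only as an illustration of the shape arithmetic (the function name and example shapes are mine, not from the patch):

def vals_shape(self_sizes, dims_before, dims_indexed, replacement_shape):
    # Mirror of the C++ valsShape: drop the dims_indexed dimensions that the
    # advanced indices consume and splice in the index shape in their place.
    shape = list(self_sizes)
    shape[dims_before:dims_before + dims_indexed] = list(replacement_shape)
    return shape

# Example: a (5, 4) tensor indexed in its last dim by a 3-element index tensor
# (dims_before=1, dims_indexed=1) expects values broadcastable to [5, 3].
assert vals_shape([5, 4], 1, 1, [3]) == [5, 3]

The kernel then does expandedValue.expand(vals_shape).contiguous(), so broadcasting is delegated to expand instead of the removed manual size juggling.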

aten/src/ATen/native/cuda/LegacyThrustHelpers.cu

Lines changed: 0 additions & 22 deletions
@@ -19,28 +19,6 @@
 
 namespace at::native {
 
-void index_put_with_sort_kernel_thrust_helper(Tensor &linearIndex, Tensor &orig_indices, Tensor &sorted_indices, int64_t num_indices) {
-  sorted_indices.copy_(linearIndex);
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  at::cuda::ThrustAllocator allocator;
-  auto policy = thrust::cuda::par(allocator).on(stream);
-
-  using device_ptr = thrust::device_ptr<int64_t>;
-
-  // Fill sortedOrigIndices with sequential indices
-  const auto count_iter = thrust::counting_iterator<int64_t>(0);
-  auto orig_data = device_ptr(orig_indices.mutable_data_ptr<int64_t>());
-  thrust::copy(policy, count_iter, count_iter + num_indices, orig_data);
-
-  // Sort the inputs into sorted with the corresponding indices; we
-  // don't need a stable or multidimensional sort, so just use Thrust
-  // directly
-  // Sort; a stable sort is not required
-  // NB - not passing comparator causes thrust to use radix sort, and it hurts perf A LOT, at least for medium (few K) sized indices
-  auto sorted_data = device_ptr(sorted_indices.mutable_data_ptr<int64_t>());
-  thrust::sort_by_key(policy, sorted_data, sorted_data + num_indices, orig_data, LTOp<int64_t>());
-}
-
 #if !CUB_SUPPORTS_SCAN_BY_KEY()
 
 template<typename index_t>

test/test_indexing.py

Lines changed: 33 additions & 13 deletions
@@ -1052,16 +1052,15 @@ def test_index_put_accumulate_non_contiguous(self, device):
         self.assertEqual(out_cuda.cpu(), out_cpu)
 
     @onlyCUDA
-    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
-    def test_index_put_accumulate_with_optional_tensors(self, device):
-        # TODO: replace with a better solution.
-        # Currently, here using torchscript to put None into indices.
-        # on C++ it gives indices as a list of 2 optional tensors: first is null and
-        # the second is a valid tensor.
-        @torch.jit.script
+    def test_index_put_deterministic_with_optional_tensors(self, device):
         def func(x, i, v):
-            idx = [None, i]
-            x.index_put_(idx, v, accumulate=True)
+            with DeterministicGuard(True):
+                x[..., i] = v
+            return x
+
+        def func1(x, i, v):
+            with DeterministicGuard(True):
+                x[i] = v
             return x
 
         n = 4
@@ -1071,13 +1070,34 @@ def func(x, i, v):
         indices_dev = indices.to(device)
         value0d = torch.tensor(10.0)
         value1d = torch.tensor([1.0, 2.0])
+        values2d = torch.randn(n, 1)
+
+        for val in (value0d, value1d, values2d):
+            out_cuda = func(t_dev, indices_dev, val.to(device))
+            out_cpu = func(t, indices, val)
+            self.assertEqual(out_cuda.cpu(), out_cpu)
 
-        out_cuda = func(t_dev, indices_dev, value0d.cuda())
-        out_cpu = func(t, indices, value0d)
+        t = torch.zeros((5, 4))
+        t_dev = t.to(device)
+        indices = torch.tensor([1, 4, 3])
+        indices_dev = indices.to(device)
+        val = torch.randn(4)
+        out_cuda = func1(t_dev, indices_dev, val.cuda())
+        out_cpu = func1(t, indices, val)
         self.assertEqual(out_cuda.cpu(), out_cpu)
 
-        out_cuda = func(t_dev, indices_dev, value1d.cuda())
-        out_cpu = func(t, indices, value1d)
+        t = torch.zeros(2, 3, 4)
+        ind = torch.tensor([0, 1])
+        val = torch.randn(6, 2)
+        with self.assertRaisesRegex(RuntimeError, "shape mismatch"):
+            func(t, ind, val)
+
+        with self.assertRaisesRegex(RuntimeError, "must match"):
+            func(t.to(device), ind.to(device), val.to(device))
+
+        val = torch.randn(2, 3, 1)
+        out_cuda = func1(t.to(device), ind.to(device), val.to(device))
+        out_cpu = func1(t, ind, val)
         self.assertEqual(out_cuda.cpu(), out_cpu)
 
     @onlyNativeDeviceTypes
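To exercise the new test locally, something like `pytest test/test_indexing.py -k index_put_deterministic` on a CUDA machine should select the device-instantiated variants; the exact test id is generated by the device-type test framework, so the `-k` substring match is the simplest filter (this invocation is a suggestion, not part of the commit).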
