
Commit a405a3e

Authored by chunhuanMeng, toyxu, and mengfei25
Fix deterministic indexing with broadcast (#1705)
Introduces enhancements to the `index_put` implementation for XPU tensors, focusing on deterministic behavior, improved shape handling, and expanded test coverage. Key changes include adding a new helper function, extending the `makeLinearIndex` and `computeLinearIndex` methods, and updating the associated test suite.

### Enhancements to `index_put` Implementation:

* **New Helper Function for Shape Handling**:
  - Introduced `valsShape` to compute the target shape for expanded values during `index_put` operations. This simplifies and centralizes shape manipulation logic (a short shape-computation sketch follows below). (`src/ATen/native/xpu/sycl/Indexing.cpp`)
* **Extended `makeLinearIndex` and `computeLinearIndex`**:
  - Added `dims_before` and `dims_indexed` to track the dimensions before and within the indexed region. These are now returned as part of the tuple from `computeLinearIndex` and propagated through `makeLinearIndex`. (`src/ATen/native/xpu/sycl/IndexingUtils.h`)
* **Simplified Value Expansion in `index_put_deterministic_kernel`**:
  - Replaced the manual size-inference and expansion logic with a call to `valsShape`. This makes the code more concise and reduces duplication. (`src/ATen/native/xpu/sycl/Indexing.cpp`)

### Test Suite Enhancements:

* **New Deterministic Tests**:
  - Added a new test, `test_index_put_deterministic_with_optional_tensors`, to validate deterministic behavior of `index_put` with various tensor shapes and scenarios, including checks for shape mismatches and proper handling of 0D, 1D, and 2D values. (`test/xpu/test_indexing_xpu.py`)

These changes collectively improve the robustness, maintainability, and test coverage of the `index_put` functionality for XPU tensors.

---------

Co-authored-by: Yutao Xu <[email protected]>
Co-authored-by: mengfei25 <[email protected]>
1 parent 7e51233 commit a405a3e
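To make the `valsShape` change concrete, here is a minimal Python sketch of the shape computation it performs, assuming the semantics shown in the `Indexing.cpp` diff below: drop the `dims_indexed` dimensions that follow the `dims_before` leading non-indexed dimensions, then splice in the broadcast shape of the index tensors. The function name `vals_shape` and the example sizes are illustrative only, not part of the patch.

```python
# Illustrative Python mirror of the valsShape helper added in this commit;
# a sketch of the semantics, not the XPU implementation itself.
def vals_shape(self_sizes, dims_before, dims_indexed, replacement_shape):
    shape = list(self_sizes)
    # Drop the dimensions that are being indexed ...
    del shape[dims_before:dims_before + dims_indexed]
    # ... and splice in the broadcast shape of the index tensors.
    shape[dims_before:dims_before] = list(replacement_shape)
    return shape

# Example: for t[:, idx] = v on a (2, 3, 4) tensor with a length-5 index,
# one leading dim is untouched (dims_before=1), one dim is indexed
# (dims_indexed=1), and the linear index has shape (5,).
print(vals_shape((2, 3, 4), 1, 1, (5,)))  # -> [2, 5, 4]
```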

4 files changed: +99 −49 lines

src/ATen/native/xpu/sycl/Indexing.cpp

Lines changed: 27 additions & 21 deletions

```diff
@@ -609,6 +609,21 @@ void index_put_kernel(
   }
 }
 
+DimVector valsShape(
+    IntArrayRef self_sizes,
+    int64_t dims_before,
+    int64_t dims_indexed,
+    IntArrayRef replacement_shape) {
+  auto shape = DimVector(self_sizes);
+  int64_t end = dims_before + dims_indexed;
+  shape.erase(shape.begin() + dims_before, shape.begin() + end);
+  shape.insert(
+      shape.begin() + dims_before,
+      replacement_shape.begin(),
+      replacement_shape.end());
+  return shape;
+}
+
 void index_put_deterministic_kernel(
     Tensor& self,
     const c10::List<std::optional<Tensor>>& indices,
@@ -633,30 +648,21 @@ void index_put_deterministic_kernel(
   bool self_contiguous = self.is_contiguous();
   auto self_ = self_contiguous ? self : self.contiguous();
   Tensor linearIndex, src, expandedValue = value;
-  int64_t nElemBefore, strideBefore, sliceSize;
+  int64_t nElemBefore, strideBefore, sliceSize, dims_before, dims_indexed;
   std::vector<int64_t> inversePerm;
   std::tie(
-      linearIndex, src, nElemBefore, strideBefore, sliceSize, inversePerm) =
-      makeLinearIndex(self_, indices, !unsafe);
+      linearIndex,
+      src,
+      nElemBefore,
+      strideBefore,
+      sliceSize,
+      inversePerm,
+      dims_before,
+      dims_indexed) = makeLinearIndex(self_, indices, !unsafe);
+  auto vals_shape =
+      valsShape(src.sizes(), dims_before, dims_indexed, linearIndex.sizes());
   int64_t num_indices = linearIndex.numel();
-
-  if (expandedValue.numel() < num_indices * nElemBefore * sliceSize) {
-    auto expanded_size = at::DimVector(expandedValue.sizes());
-
-    auto size1 = expandedValue.sizes();
-    auto size2 = linearIndex.sizes();
-    if (are_expandable(size1, size2)) {
-      expanded_size = infer_size_dimvector(size1, size2);
-    }
-    if (nElemBefore > 1) {
-      expanded_size.insert(expanded_size.begin(), nElemBefore);
-    }
-    if (sliceSize > 1) {
-      expanded_size.insert(expanded_size.end(), sliceSize);
-    }
-    expandedValue = expandedValue.expand(expanded_size);
-  }
-  expandedValue = expandedValue.contiguous();
+  expandedValue = expandedValue.expand(vals_shape).contiguous();
 
   if (num_indices > 0 && sliceSize > 0) {
     const bool permuted = !src.is_contiguous();
```
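From the user's side, the simplified expansion path above is what the deterministic broadcast case exercises. Below is a hedged usage sketch of that scenario; the device selection and shapes are illustrative (they mirror one of the new test cases), not code from this commit.

```python
import torch

# Sketch of the scenario this kernel change targets: an index_put under
# deterministic mode whose value tensor must be broadcast. Assumes an XPU
# build of PyTorch; falls back to CPU so the snippet stays runnable.
device = "xpu" if getattr(torch, "xpu", None) and torch.xpu.is_available() else "cpu"

torch.use_deterministic_algorithms(True)

x = torch.zeros(2, 3, 4, device=device)
idx = torch.tensor([0, 1], device=device)
v = torch.randn(2, 3, 1, device=device)  # broadcasts along the last dim

# Advanced-indexing assignment lowers to index_put_; the deterministic XPU
# path now expands v to the valsShape-computed target before scattering.
x[idx] = v

torch.use_deterministic_algorithms(False)
```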

src/ATen/native/xpu/sycl/IndexingUtils.h

Lines changed: 36 additions & 11 deletions

```diff
@@ -57,10 +57,8 @@ static std::vector<int64_t> computeLinearStride(const Tensor& tensor) {
   return stride;
 }
 
-static std::tuple<Tensor, int64_t, int64_t, int64_t> computeLinearIndex(
-    const Tensor& src,
-    TensorList indices,
-    bool check_range) {
+static std::tuple<Tensor, int64_t, int64_t, int64_t, int64_t, int64_t>
+computeLinearIndex(const Tensor& src, TensorList indices, bool check_range) {
   auto strides = computeLinearStride(src);
   const auto& device = src.options().device();
 
@@ -70,8 +68,10 @@ static std::tuple<Tensor, int64_t, int64_t, int64_t> computeLinearIndex(
   // are not being index.
   Tensor linearIndex;
   int64_t nElemBefore = 1, nElemAfter = 1, strideBefore = 0;
+  int64_t dims_before = 0, dims_indexed = 0;
   for (const auto i : c10::irange(src.dim())) {
     if (indices[i].defined()) {
+      dims_indexed++;
       // Cast index to the longType matching src's device
       // This allows us to support ie indexing a xpu tensor with a cpu tensor
       Tensor index =
@@ -88,17 +88,30 @@ static std::tuple<Tensor, int64_t, int64_t, int64_t> computeLinearIndex(
     } else if (linearIndex.defined()) {
       nElemAfter *= src.size(i);
     } else {
+      dims_before++;
       nElemBefore *= src.size(i);
     }
   }
 
   return std::make_tuple(
-      std::move(linearIndex), nElemBefore, strideBefore, nElemAfter);
+      std::move(linearIndex),
+      nElemBefore,
+      strideBefore,
+      nElemAfter,
+      dims_before,
+      dims_indexed);
 }
 
-static std::
-    tuple<Tensor, Tensor, int64_t, int64_t, int64_t, std::vector<int64_t>>
-    makeLinearIndex(Tensor self, IOptTensorListRef orig, bool check_range) {
+static std::tuple<
+    Tensor,
+    Tensor,
+    int64_t,
+    int64_t,
+    int64_t,
+    std::vector<int64_t>,
+    int64_t,
+    int64_t>
+makeLinearIndex(Tensor self, IOptTensorListRef orig, bool check_range) {
   checkIndexTensorTypes(orig, /*allow_int*/ true);
   // first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more
   // LongTensors
@@ -121,10 +134,22 @@ static std::
     std::tie(self, indices, inversePerm) =
         transposeToFrontAndInvPerm(self, indices);
   }
-  auto [linearIndex, nElemBefore, strideBefore, nElemAfter] =
-      computeLinearIndex(self, indices, check_range);
+  auto
+      [linearIndex,
+       nElemBefore,
+       strideBefore,
+       nElemAfter,
+       dims_before,
+       dims_indexed] = computeLinearIndex(self, indices, check_range);
   return std::make_tuple(
-      linearIndex, self, nElemBefore, strideBefore, nElemAfter, inversePerm);
+      linearIndex,
+      self,
+      nElemBefore,
+      strideBefore,
+      nElemAfter,
+      inversePerm,
+      dims_before,
+      dims_indexed);
 }
 
 } // namespace at::native::xpu
```
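The new `dims_before` / `dims_indexed` counters follow the same walk over `src.dim()` shown above: non-indexed dimensions ahead of the first index tensor bump `dims_before`, defined index tensors bump `dims_indexed`, and trailing non-indexed dimensions only feed `nElemAfter`. Here is a toy Python illustration of that counting; the mask and helper are hypothetical, written only to mirror the loop, not part of the patch.

```python
# Toy re-implementation of the dims_before / dims_indexed bookkeeping that
# computeLinearIndex now returns; defined_mask[i] says whether dimension i
# is covered by an index tensor (i.e. indices[i].defined() in the C++).
def count_dims(defined_mask):
    dims_before = dims_indexed = 0
    seen_indexed = False
    for is_indexed in defined_mask:
        if is_indexed:
            dims_indexed += 1
            seen_indexed = True
        elif not seen_indexed:
            dims_before += 1  # non-indexed dims ahead of the indexed block
        # non-indexed dims after the block only contribute to nElemAfter
    return dims_before, dims_indexed

# x[:, i, j] on a 4-D tensor -> indices are [None, i, j, None]
print(count_dims([False, True, True, False]))  # -> (1, 2)
```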

test/xpu/skip_list_common.py

Lines changed: 0 additions & 2 deletions

```diff
@@ -1036,8 +1036,6 @@
         # https://github.com/intel/torch-xpu-ops/issues/461
         "test_index_put_src_datatype_xpu_float8_e5m2",
         "test_index_put_src_datatype_xpu_float8_e4m3fn",
-        # https://github.com/intel/torch-xpu-ops/issues/1702
-        "test_index_put_deterministic_with_optional_tensors_xpu",
     ),
     "nn/test_pooling_xpu.py": None,
     "nn/test_dropout_xpu.py": None,
```

test/xpu/test_indexing_xpu.py

Lines changed: 36 additions & 15 deletions

```diff
@@ -1,7 +1,7 @@
 # Owner(s): ["module: intel"]
 
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
-from torch.testing._internal.common_utils import run_tests
+from torch.testing._internal.common_utils import DeterministicGuard, run_tests
 
 try:
     from xpu_test_utils import XPUPatchForImport
@@ -14,15 +14,15 @@
 
     torch.Tensor.is_cuda = torch.Tensor.is_xpu
 
-    def __test_index_put_accumulate_with_optional_tensors(self, device):
-        # TODO: replace with a better solution.
-        # Currently, here using torchscript to put None into indices.
-        # on C++ it gives indices as a list of 2 optional tensors: first is null and
-        # the second is a valid tensor.
-        @torch.jit.script
+    def __test_index_put_deterministic_with_optional_tensors(self, device):
         def func(x, i, v):
-            idx = [None, i]
-            x.index_put_(idx, v, accumulate=True)
+            with DeterministicGuard(True):
+                x[..., i] = v
+            return x
+
+        def func1(x, i, v):
+            with DeterministicGuard(True):
+                x[i] = v
             return x
 
         n = 4
@@ -32,17 +32,38 @@ def func(x, i, v):
         indices_dev = indices.to(device)
         value0d = torch.tensor(10.0)
         value1d = torch.tensor([1.0, 2.0])
+        values2d = torch.randn(n, 1)
 
-        out_cuda = func(t_dev, indices_dev, value0d.xpu())
-        out_cpu = func(t, indices, value0d)
+        for val in (value0d, value1d, values2d):
+            out_cuda = func(t_dev, indices_dev, val.to(device))
+            out_cpu = func(t, indices, val)
+            self.assertEqual(out_cuda.cpu(), out_cpu)
+
+        t = torch.zeros((5, 4))
+        t_dev = t.to(device)
+        indices = torch.tensor([1, 4, 3])
+        indices_dev = indices.to(device)
+        val = torch.randn(4)
+        out_cuda = func1(t_dev, indices_dev, val.xpu())
+        out_cpu = func1(t, indices, val)
         self.assertEqual(out_cuda.cpu(), out_cpu)
 
-        out_cuda = func(t_dev, indices_dev, value1d.xpu())
-        out_cpu = func(t, indices, value1d)
+        t = torch.zeros(2, 3, 4)
+        ind = torch.tensor([0, 1])
+        val = torch.randn(6, 2)
+        with self.assertRaisesRegex(RuntimeError, "shape mismatch"):
+            func(t, ind, val)
+
+        with self.assertRaisesRegex(RuntimeError, "must match"):
+            func(t.to(device), ind.to(device), val.to(device))
+
+        val = torch.randn(2, 3, 1)
+        out_cuda = func1(t.to(device), ind.to(device), val.to(device))
+        out_cpu = func1(t, ind, val)
         self.assertEqual(out_cuda.cpu(), out_cpu)
 
-    TestIndexing.test_index_put_accumulate_with_optional_tensors = (
-        __test_index_put_accumulate_with_optional_tensors
+    TestIndexing.test_index_put_deterministic_with_optional_tensors = (
+        __test_index_put_deterministic_with_optional_tensors
     )
 
 instantiate_device_type_tests(NumpyTests, globals(), only_for=("xpu"), allow_xpu=True)
```
