Skip to content

Commit 06a66ea

Browse files
support direct indexing for DenseTensor.dims() (PaddlePaddle#76331)
1 parent 01b7256 commit 06a66ea

File tree

2 files changed

+90
-0
lines changed

2 files changed

+90
-0
lines changed

paddle/phi/core/dense_tensor.h

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,46 @@ class PADDLE_API DenseTensor : public TensorBase,
8585
/// \return The dims of the tensor.
8686
const DDim& dims() const noexcept override { return meta_.dims; }
8787

88+
/// \brief Returns the size of the tensor along the specified dimension.
89+
/// Supports negative indices, which count from the last dimension.
90+
/// \param dim The dimension index to retrieve. Must be in the range [0, ndim)
91+
/// or [-ndim, -1]. \return The size of the tensor along the given dimension.
92+
/// \throws phi::errors::OutOfRange if the tensor is empty or the index is out
93+
/// of range.
94+
const int64_t dims(int dim) const {
95+
int ndim = meta_.dims.size();
96+
97+
// Ensure the tensor has at least one dimension
98+
PADDLE_ENFORCE_GE(ndim,
99+
1,
100+
phi::errors::OutOfRange(
101+
"dims expects at least a 1-dimensional tensor"));
102+
103+
// Check if the index is within the valid range [-ndim, ndim)
104+
PADDLE_ENFORCE_GE(
105+
dim,
106+
-ndim,
107+
phi::errors::OutOfRange(
108+
"dims: dimension index (%d) must be in range [-%d, %d)",
109+
dim,
110+
ndim,
111+
ndim));
112+
PADDLE_ENFORCE_LT(
113+
dim,
114+
ndim,
115+
phi::errors::OutOfRange(
116+
"dims: dimension index (%d) must be in range [-%d, %d)",
117+
dim,
118+
ndim,
119+
ndim));
120+
121+
// Handle negative indices
122+
if (dim < 0) {
123+
dim += ndim;
124+
}
125+
return meta_.dims[dim];
126+
}
127+
88128
/// \brief Returns the stride of the tensor.
89129
/// \return The stride of the tensor.
90130
const DDim& strides() const noexcept { return meta_.strides; }

test/cpp/phi/core/test_dense_tensor.cc

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,56 @@ TEST(dense_tensor, shallow_copy) {
280280
"tensor_1 to have the same meta"));
281281
}
282282

283+
TEST(dense_tensor, dim_indexing) {
284+
const DDim dims({4, 3, 2, 0});
285+
const DataType dtype{DataType::INT8};
286+
const DataLayout layout{DataLayout::NHWC};
287+
const LegacyLoD lod{};
288+
DenseTensorMeta meta(dtype, dims, layout, lod);
289+
290+
auto fancy_allocator = std::unique_ptr<Allocator>(new FancyAllocator);
291+
auto* alloc = fancy_allocator.get();
292+
DenseTensor tensor_0(alloc, meta);
293+
int ndim = tensor_0.dims().size();
294+
auto tensor_0_dims = tensor_0.dims();
295+
for (int i = -ndim; i < ndim; ++i) {
296+
PADDLE_ENFORCE_EQ(
297+
tensor_0_dims[(i + ndim) % ndim],
298+
tensor_0.dims(i),
299+
common::errors::InvalidArgument(
300+
"Dimension mismatch at index %d. Expected: %d, but got: %d",
301+
i,
302+
tensor_0_dims[i],
303+
tensor_0.dims(i)));
304+
}
305+
306+
// throw exception for index >= ndim
307+
bool caught_exception = false;
308+
try {
309+
tensor_0.dims(ndim);
310+
} catch (const common::enforce::EnforceNotMet& error) {
311+
caught_exception = true;
312+
}
313+
PADDLE_ENFORCE_EQ(
314+
caught_exception,
315+
true,
316+
common::errors::InvalidArgument(
317+
"Expected an exception to be thrown for index >= ndim"));
318+
319+
// throw exception for index < -ndim
320+
caught_exception = false;
321+
try {
322+
tensor_0.dims(-ndim - 1);
323+
} catch (const common::enforce::EnforceNotMet& error) {
324+
caught_exception = true;
325+
}
326+
PADDLE_ENFORCE_EQ(
327+
caught_exception,
328+
true,
329+
common::errors::InvalidArgument(
330+
"Expected an exception to be thrown for index < -ndim"));
331+
}
332+
283333
TEST(dense_tensor, storage_properties) {
284334
const DataType dtype{DataType::FLOAT32};
285335
const DDim dims({1, 2});

0 commit comments

Comments
 (0)