Skip to content

Commit c34161d

Browse files
authored
[Compat] Fix transpose implementation and add negative indexing support for size (PaddlePaddle#75900)
1 parent c5e8259 commit c34161d

File tree

3 files changed

+23
-5
lines changed

3 files changed

+23
-5
lines changed

paddle/phi/api/include/compat/ATen/core/TensorBase.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,9 @@ class PADDLE_API TensorBase {
6666
}
6767

6868
int64_t size(int64_t dim) const {
69+
if (dim < 0) {
70+
dim += tensor_.dims().size();
71+
}
6972
return tensor_.dims()[static_cast<int>(dim)];
7073
}
7174

paddle/phi/api/include/compat/ATen/core/TensorBody.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,6 @@ class Tensor : public TensorBase {
5353
}
5454

5555
using TensorBase::size;
56-
// int64_t size(int64_t dim) const {
57-
// return tensor_.dims()[static_cast<int>(dim)];
58-
// }
5956

6057
c10::IntArrayRef sizes() const {
6158
return compat::_PD_PhiDDimToIntArrayRef(tensor_.dims());
@@ -119,8 +116,12 @@ class Tensor : public TensorBase {
119116
}
120117

121118
at::Tensor transpose(int64_t dim0, int64_t dim1) const {
122-
return Tensor(paddle::experimental::transpose(
123-
tensor_, {static_cast<int>(dim0), static_cast<int>(dim1)}));
119+
std::vector<int> perm(tensor_.dims().size());
120+
for (size_t i = 0; i < perm.size(); i++) {
121+
perm[i] = static_cast<int>(i);
122+
}
123+
std::swap(perm[dim0], perm[dim1]);
124+
return Tensor(paddle::experimental::transpose(tensor_, perm));
124125
}
125126

126127
at::Tensor& copy_(const at::Tensor& src, bool non_blocking = false) const {

test/cpp/compat/compat_basic_test.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,3 +298,17 @@ TEST(TestDevice, DeviceAPIsOnCPU) {
298298
auto options = cpu_tensor.options();
299299
ASSERT_EQ(options.device().type(), at::DeviceType::CPU);
300300
}
301+
302+
TEST(TestTranspose, TransposeAPI) {
303+
at::Tensor a = at::ones({4, 5, 6, 7, 8}, at::kFloat);
304+
at::Tensor b = a.transpose(2, 3);
305+
ASSERT_EQ(b.sizes(), c10::IntArrayRef({4, 5, 7, 6, 8}));
306+
}
307+
308+
TEST(TestSize, SizeNegativeIndex) {
309+
at::Tensor tensor = at::ones({2, 3, 4, 5}, at::kFloat);
310+
ASSERT_EQ(tensor.size(-1), 5);
311+
ASSERT_EQ(tensor.size(-2), 4);
312+
ASSERT_EQ(tensor.size(-3), 3);
313+
ASSERT_EQ(tensor.size(-4), 2);
314+
}

0 commit comments

Comments
 (0)