Skip to content

Commit a891032

Browse files
authored
[Cherry-pick] Fix dtype unmatched in custom op API #31306
[Cherry-pick] Fix dtype unmatched in custom op API cherry-pick of #31305
1 parent 628f085 commit a891032

File tree

3 files changed

+10
-10
lines changed

3 files changed

+10
-10
lines changed

paddle/fluid/extension/include/ext_tensor.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ class PD_DLL_DECL Tensor {
5757
/// Reshape must be called before calling
5858
/// mutable_data() or copy_to(const PlaceType& place)
5959
/// \param shape The shape to set.
60-
void reshape(const std::vector<int>& shape);
60+
void reshape(const std::vector<int64_t>& shape);
6161

6262
/// \brief Get the memory pointer in CPU or GPU with
6363
/// specific data type.
@@ -90,7 +90,7 @@ class PD_DLL_DECL Tensor {
9090
Tensor copy_to(const PlaceType& place) const;
9191

9292
/// \brief Return the shape of the Tensor.
93-
std::vector<int> shape() const;
93+
std::vector<int64_t> shape() const;
9494

9595
/// \brief Return the data type of the tensor.
9696
/// It's usually used to get the output tensor data type.

paddle/fluid/extension/src/ext_tensor.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ void GpuCopy(T *src, T *dst, PlaceType src_plc, PlaceType dst_plc,
9595
} \
9696
auto *tensor = static_cast<framework::LoDTensor *>(tensor_.get());
9797

98-
void Tensor::reshape(const std::vector<int> &shape) {
98+
void Tensor::reshape(const std::vector<int64_t> &shape) {
9999
GET_CASTED_TENSOR
100100
tensor->Resize(framework::make_ddim(shape));
101101
}
@@ -251,9 +251,9 @@ template PD_DLL_DECL int16_t *Tensor::mutable_data<int16_t>(
251251
const PlaceType &place);
252252
template PD_DLL_DECL bool *Tensor::mutable_data<bool>(const PlaceType &place);
253253

254-
std::vector<int> Tensor::shape() const {
254+
std::vector<int64_t> Tensor::shape() const {
255255
GET_CASTED_TENSOR
256-
return framework::vectorize<int>(tensor->dims());
256+
return framework::vectorize<int64_t>(tensor->dims());
257257
}
258258

259259
const PlaceType &Tensor::place() const {

paddle/fluid/framework/custom_tensor_test.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
template <typename T>
2222
paddle::Tensor InitCPUTensorForTest() {
23-
std::vector<int> tensor_shape{5, 5};
23+
std::vector<int64_t> tensor_shape{5, 5};
2424
auto t1 = paddle::Tensor(paddle::PlaceType::kCPU);
2525
t1.reshape(tensor_shape);
2626
auto* p_data_ptr = t1.mutable_data<T>(paddle::PlaceType::kCPU);
@@ -54,7 +54,7 @@ void TestCopyTensor() {
5454
}
5555

5656
void TestAPIPlace() {
57-
std::vector<int> tensor_shape = {5, 5};
57+
std::vector<int64_t> tensor_shape = {5, 5};
5858
#ifdef PADDLE_WITH_CUDA
5959
auto t1 = paddle::Tensor(paddle::PlaceType::kGPU);
6060
t1.reshape(tensor_shape);
@@ -68,7 +68,7 @@ void TestAPIPlace() {
6868
}
6969

7070
void TestAPISizeAndShape() {
71-
std::vector<int> tensor_shape = {5, 5};
71+
std::vector<int64_t> tensor_shape = {5, 5};
7272
auto t1 = paddle::Tensor(paddle::PlaceType::kCPU);
7373
t1.reshape(tensor_shape);
7474
CHECK_EQ(t1.size(), 25);
@@ -77,7 +77,7 @@ void TestAPISizeAndShape() {
7777

7878
template <typename T>
7979
paddle::DataType TestDtype() {
80-
std::vector<int> tensor_shape = {5, 5};
80+
std::vector<int64_t> tensor_shape = {5, 5};
8181
auto t1 = paddle::Tensor(paddle::PlaceType::kCPU);
8282
t1.reshape(tensor_shape);
8383
t1.template mutable_data<T>();
@@ -86,7 +86,7 @@ paddle::DataType TestDtype() {
8686

8787
template <typename T>
8888
void TestCast(paddle::DataType data_type) {
89-
std::vector<int> tensor_shape = {5, 5};
89+
std::vector<int64_t> tensor_shape = {5, 5};
9090
auto t1 = paddle::Tensor(paddle::PlaceType::kCPU);
9191
t1.reshape(tensor_shape);
9292
t1.template mutable_data<T>();

0 commit comments

Comments
 (0)