Commit e256d5c

[NPU, MLU] Fix dtype mismatch in interpolate_kernel (#1618)

1 parent df14333 commit e256d5c

File tree

4 files changed: +72 -40 lines changed

backends/mlu/kernels/funcs/mlu_funcs.h

Lines changed: 2 additions & 0 deletions
@@ -240,6 +240,8 @@ inline void TensorToVector(const phi::CustomContext& ctx,
 
   if (src_place.GetType() == phi::AllocationType::CUSTOM) {
     MemCpyD2H(&device, dst_ptr, src_ptr, size);
+  } else if (src_place.GetType() == phi::AllocationType::CPU) {
+    std::memcpy(dst_ptr, src_ptr, size);
   } else {
     PADDLE_THROW(phi::errors::Unimplemented(
         "TensorToVector on %s is not supported.", src_place));

backends/mlu/kernels/interpolate_kernel.cc

File mode changed from 100755 to 100644
Lines changed: 25 additions & 26 deletions
@@ -23,14 +23,24 @@ inline std::vector<int> get_new_shape_mlu(
   std::vector<int> vec_new_shape;
   for (size_t i = 0; i < list_new_shape_tensor.size(); ++i) {
     auto tensor = list_new_shape_tensor[i];
-    PADDLE_ENFORCE_EQ(
-        tensor->dims(),
-        phi::make_ddim({1}),
-        phi::errors::InvalidArgument("shape of dim tensor should be [1]"));
-    std::vector<int32_t> temp_vec(1);
-    dev_ctx.Wait();
-    TensorToVector(dev_ctx, *tensor, dev_ctx, &temp_vec);
-    vec_new_shape.push_back(temp_vec[0]);
+    PADDLE_ENFORCE_EQ(tensor->dims() == phi::make_ddim({1}) ||
+                          tensor->dims() == phi::make_ddim({}),
+                      true,
+                      phi::errors::InvalidArgument(
+                          "The shape of dimension tensor should be [1] or [], "
+                          "but received %s.",
+                          tensor->dims()));
+    if (tensor->dtype() == phi::DataType::INT64) {
+      std::vector<int64_t> temp_vec(1);
+      dev_ctx.Wait();
+      TensorToVector(dev_ctx, *tensor, dev_ctx, &temp_vec);
+      vec_new_shape.push_back(temp_vec[0]);
+    } else if (tensor->dtype() == phi::DataType::INT32) {
+      std::vector<int32_t> temp_vec(1);
+      dev_ctx.Wait();
+      TensorToVector(dev_ctx, *tensor, dev_ctx, &temp_vec);
+      vec_new_shape.push_back(temp_vec[0]);
+    }
   }
 
   return vec_new_shape;
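
Why the dispatch matters: the old helper always read the shape tensor through a std::vector<int32_t>, even when the tensor actually stored INT64 data, which is the dtype mismatch the commit title refers to. Below is a minimal sketch of the same idea in plain C++ (the DType enum and function names are hypothetical, not the Paddle API): copy exactly as many bytes as the stored dtype occupies, and narrow to int only afterwards.

#include <cstdint>
#include <cstring>
#include <iostream>

enum class DType { kInt32, kInt64 };

// Read one scalar from a raw buffer whose element type is only known at
// runtime. Copying through the wrong element size is exactly the hazard
// the dtype branches above avoid.
int ReadScalarAsInt(const void* data, DType dtype) {
  if (dtype == DType::kInt64) {
    int64_t v = 0;
    std::memcpy(&v, data, sizeof(v));  // full 8-byte read
    return static_cast<int>(v);        // narrow only after the full read
  }
  int32_t v = 0;
  std::memcpy(&v, data, sizeof(v));    // 4-byte read for INT32
  return v;
}

int main() {
  int64_t h64 = 224;  // e.g. an out_h SizeTensor stored as INT64
  int32_t w32 = 320;  // e.g. an out_w SizeTensor stored as INT32
  std::cout << ReadScalarAsInt(&h64, DType::kInt64) << "x"
            << ReadScalarAsInt(&w32, DType::kInt32) << "\n";  // 224x320
}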
@@ -75,25 +85,14 @@ void InterpolateKernel(
   if (size_tensor && size_tensor->size() > 0) {
     // have SizeTensor
     VLOG(5) << "[Interp] get out_h and out_w from SizeTensor";
-    auto list_new_shape_tensor = size_tensor.get();
-
-    if (list_new_shape_tensor.size() <= 2) {
-      auto output_h =
-          get_new_data_from_tensor<int>(dev_ctx, list_new_shape_tensor[0]);
-      auto output_w =
-          get_new_data_from_tensor<int>(dev_ctx, list_new_shape_tensor[1]);
-      out_h = output_h[0];
-      out_w = output_w[0];
+    auto output_get = get_new_shape_mlu(dev_ctx, size_tensor.get());
+    if (output_get.size() <= 2) {
+      out_h = output_get[0];
+      out_w = output_get[1];
     } else {
-      auto output_d =
-          get_new_data_from_tensor<int>(dev_ctx, list_new_shape_tensor[0]);
-      auto output_h =
-          get_new_data_from_tensor<int>(dev_ctx, list_new_shape_tensor[1]);
-      auto output_w =
-          get_new_data_from_tensor<int>(dev_ctx, list_new_shape_tensor[2]);
-      out_h = output_h[0];
-      out_w = output_w[0];
-      out_d = output_d[0];
+      out_h = output_get[0];
+      out_w = output_get[1];
+      out_d = output_get[2];
     }
   } else if (out_size) {
     VLOG(5) << "[Interp] get out_h and out_w from OutSize";

backends/npu/kernels/funcs/npu_funcs.h

Lines changed: 2 additions & 0 deletions
@@ -281,6 +281,8 @@ inline void TensorToVector(const phi::CustomContext& ctx,
     AsyncMemCpyD2H(
         &device, static_cast<C_Stream>(ctx.stream()), dst_ptr, src_ptr, size);
     ctx.Wait();
+  } else if (src_place.GetType() == phi::AllocationType::CPU) {
+    std::memcpy(dst_ptr, src_ptr, size);
   } else {
     PADDLE_THROW(phi::errors::Unimplemented(
         "TensorToVector on %s is not supported.", src_place));

backends/npu/kernels/interpolate_kernel.cc

Lines changed: 43 additions & 14 deletions
@@ -17,6 +17,37 @@
 #include "kernels/funcs/slice_utils.h"
 
 namespace custom_kernel {
+
+inline std::vector<int> get_new_shape_npu(
+    const phi::CustomContext& dev_ctx,
+    const std::vector<const phi::DenseTensor*>& list_new_shape_tensor) {
+  // read one scalar shape value from each tensor in the list
+  std::vector<int> vec_new_shape;
+  for (size_t i = 0; i < list_new_shape_tensor.size(); ++i) {
+    auto tensor = list_new_shape_tensor[i];
+    PADDLE_ENFORCE_EQ(tensor->dims() == phi::make_ddim({1}) ||
+                          tensor->dims() == phi::make_ddim({}),
+                      true,
+                      phi::errors::InvalidArgument(
+                          "The shape of dimension tensor should be [1] or [], "
+                          "but received %s.",
+                          tensor->dims()));
+    if (tensor->dtype() == phi::DataType::INT64) {
+      std::vector<int64_t> temp_vec(1);
+      dev_ctx.Wait();
+      TensorToVector(dev_ctx, *tensor, dev_ctx, &temp_vec);
+      vec_new_shape.push_back(temp_vec[0]);
+    } else if (tensor->dtype() == phi::DataType::INT32) {
+      std::vector<int32_t> temp_vec(1);
+      dev_ctx.Wait();
+      TensorToVector(dev_ctx, *tensor, dev_ctx, &temp_vec);
+      vec_new_shape.push_back(temp_vec[0]);
+    }
+  }
+
+  return vec_new_shape;
+}
+
 template <typename T, typename Context>
 void TransposeKernel(const Context& dev_ctx,
                      const phi::DenseTensor& x,
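
This helper duplicates get_new_shape_mlu for the NPU backend. Note the relaxed shape check: a dimension tensor may now be 0-D (phi::make_ddim({})) as well as [1]-shaped. A small self-contained sketch of that predicate, where a plain vector stands in for phi::DDim (an assumption; this is not the Paddle API):

#include <cstdint>
#include <iostream>
#include <vector>

// A shape tensor qualifies when it is 0-D (shape []) or 1-D with a single
// element (shape [1]); anything larger holds more than one value.
bool IsScalarLikeDims(const std::vector<int64_t>& dims) {
  return dims.empty() ||                      // 0-D scalar tensor: []
         (dims.size() == 1 && dims[0] == 1);  // one-element vector: [1]
}

int main() {
  std::cout << std::boolalpha
            << IsScalarLikeDims({}) << " "     // true  (0-D)
            << IsScalarLikeDims({1}) << " "    // true  ([1])
            << IsScalarLikeDims({2}) << "\n";  // false (two values)
}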
@@ -798,13 +829,15 @@ void InterpolateKernel(
 
   // Priority: SizeTensor > OutSize > Scale > scale > out_h & out_w
   if (size_tensor && size_tensor->size() > 0) {
-    auto list_new_shape_tensor = size_tensor.get();
-    auto output_h =
-        get_new_data_from_tensor<int>(dev_ctx, list_new_shape_tensor[0]);
-    auto output_w =
-        get_new_data_from_tensor<int>(dev_ctx, list_new_shape_tensor[1]);
-    out_h = output_h[0];
-    out_w = output_w[0];
+    auto output_get = get_new_shape_npu(dev_ctx, size_tensor.get());
+    if (output_get.size() <= 2) {
+      out_h = output_get[0];
+      out_w = output_get[1];
+    } else {
+      out_h = output_get[0];
+      out_w = output_get[1];
+      out_d = output_get[2];
+    }
   } else {
     if (scale_tensor) {
       auto scale_data =
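
The kernel resolves the output size through the fallback chain named in the comment above (SizeTensor > OutSize > Scale > scale > out_h & out_w). A compact sketch of that resolution pattern for one dimension, with std::optional standing in for paddle::optional and all names hypothetical:

#include <iostream>
#include <optional>

// Take the first size source that is present, in priority order.
int ResolveOutH(std::optional<int> size_tensor_h,
                std::optional<int> out_size_h,
                std::optional<float> scale,
                int in_h,
                int attr_out_h) {
  if (size_tensor_h) return *size_tensor_h;  // 1. SizeTensor
  if (out_size_h) return *out_size_h;        // 2. OutSize
  if (scale && *scale > 0) {
    return static_cast<int>(in_h * *scale);  // 3. scale factor
  }
  return attr_out_h;                         // 4. attribute out_h
}

int main() {
  // No tensors given: falls through to the scale, 4 * 2.0 -> 8.
  std::cout << ResolveOutH(std::nullopt, std::nullopt, 2.0f, 4, -1) << "\n";
}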
@@ -983,13 +1016,9 @@ void InterpolateGradKernel(
 
   // Priority: SizeTensor > OutSize > Scale > scale > out_h & out_w
   if (size_tensor && size_tensor->size() > 0) {
-    auto list_new_size_tensor = size_tensor.get();
-    std::vector<int32_t> output_h(1);
-    std::vector<int32_t> output_w(1);
-    TensorToVector(dev_ctx, *(list_new_size_tensor[0]), dev_ctx, &output_h);
-    TensorToVector(dev_ctx, *(list_new_size_tensor[1]), dev_ctx, &output_w);
-    out_h = output_h[0];
-    out_w = output_w[0];
+    auto output_get = get_new_shape_npu(dev_ctx, size_tensor.get());
+    out_h = output_get[0];
+    out_w = output_get[1];
   } else if (out_size) {
     auto out_size_data =
         get_new_data_from_tensor<int>(dev_ctx, out_size.get_ptr());
