@@ -17,6 +17,37 @@
 #include "kernels/funcs/slice_utils.h"
 
 namespace custom_kernel {
+
+inline std::vector<int> get_new_shape_npu(
+    const phi::CustomContext& dev_ctx,
+    const std::vector<const phi::DenseTensor*>& list_new_shape_tensor) {
+  // Copy each scalar shape tensor back to the host and collect the values.
+  std::vector<int> vec_new_shape;
+  for (size_t i = 0; i < list_new_shape_tensor.size(); ++i) {
+    auto tensor = list_new_shape_tensor[i];
+    PADDLE_ENFORCE_EQ(tensor->dims() == phi::make_ddim({1}) ||
+                          tensor->dims() == phi::make_ddim({}),
+                      true,
+                      phi::errors::InvalidArgument(
+                          "The shape of dimension tensor should be [1] or [], "
+                          "but received %s.",
+                          tensor->dims()));
+    if (tensor->dtype() == phi::DataType::INT64) {
+      std::vector<int64_t> temp_vec(1);
+      dev_ctx.Wait();
+      TensorToVector(dev_ctx, *tensor, dev_ctx, &temp_vec);
+      vec_new_shape.push_back(static_cast<int>(temp_vec[0]));
+    } else if (tensor->dtype() == phi::DataType::INT32) {
+      std::vector<int32_t> temp_vec(1);
+      dev_ctx.Wait();
+      TensorToVector(dev_ctx, *tensor, dev_ctx, &temp_vec);
+      vec_new_shape.push_back(temp_vec[0]);
+    }
+  }
+
+  return vec_new_shape;
+}
+
 template <typename T, typename Context>
 void TransposeKernel(const Context& dev_ctx,
                      const phi::DenseTensor& x,
@@ -798,13 +829,15 @@ void InterpolateKernel(
 
   // Priority: SizeTensor > OutSize > Scale > scale > out_h & out_w
   if (size_tensor && size_tensor->size() > 0) {
-    auto list_new_shape_tensor = size_tensor.get();
-    auto output_h =
-        get_new_data_from_tensor<int>(dev_ctx, list_new_shape_tensor[0]);
-    auto output_w =
-        get_new_data_from_tensor<int>(dev_ctx, list_new_shape_tensor[1]);
-    out_h = output_h[0];
-    out_w = output_w[0];
+    auto output_get = get_new_shape_npu(dev_ctx, size_tensor.get());
+    if (output_get.size() == 2) {  // 4-D input: (h, w)
+      out_h = output_get[0];
+      out_w = output_get[1];
+    } else if (output_get.size() == 3) {  // 5-D input: (d, h, w)
+      out_d = output_get[0];
+      out_h = output_get[1];
+      out_w = output_get[2];
+    }
   } else {
     if (scale_tensor) {
       auto scale_data =
@@ -983,13 +1016,9 @@ void InterpolateGradKernel(
 
   // Priority: SizeTensor > OutSize > Scale > scale > out_h & out_w
   if (size_tensor && size_tensor->size() > 0) {
-    auto list_new_size_tensor = size_tensor.get();
-    std::vector<int32_t> output_h(1);
-    std::vector<int32_t> output_w(1);
-    TensorToVector(dev_ctx, *(list_new_size_tensor[0]), dev_ctx, &output_h);
-    TensorToVector(dev_ctx, *(list_new_size_tensor[1]), dev_ctx, &output_w);
-    out_h = output_h[0];
-    out_w = output_w[0];
+    auto output_get = get_new_shape_npu(dev_ctx, size_tensor.get());
+    out_h = output_get[0];
+    out_w = output_get[1];
   } else if (out_size) {
     auto out_size_data =
         get_new_data_from_tensor<int>(dev_ctx, out_size.get_ptr());
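
Note (reviewer sketch, not part of the patch): a minimal example of how the new
helper is meant to be driven by both kernels, assuming the phi custom-device
headers and a live CustomContext; `read_interp_size` is a hypothetical wrapper
name used only for illustration.

  // Unpacks the per-dimension size tensors in the (d, h, w) order that
  // Paddle's interpolate kernels use for 5-D inputs, and (h, w) for 4-D.
  void read_interp_size(const phi::CustomContext& dev_ctx,
                        const std::vector<const phi::DenseTensor*>& sizes,
                        int* out_d, int* out_h, int* out_w) {
    // One host round-trip per scalar tensor, INT32 or INT64, 0-D or [1].
    std::vector<int> shape = custom_kernel::get_new_shape_npu(dev_ctx, sizes);
    if (shape.size() == 2) {         // 4-D input: (h, w)
      *out_h = shape[0];
      *out_w = shape[1];
    } else if (shape.size() == 3) {  // 5-D input: (d, h, w)
      *out_d = shape[0];
      *out_h = shape[1];
      *out_w = shape[2];
    }
  }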