@@ -287,6 +287,84 @@ inline aclTensor* ConvertType(const phi::DenseTensor& at_tensor) {
287
287
return acl_tensor;
288
288
}
289
289
290
+ template <typename T>
291
+ inline aclTensor* ConvertType (const std::pair<T*,
292
+ std::vector<int64_t >>& tp_tensor) {
293
+ // Ptr ConvertType to aclTensor
294
+ static const auto aclCreateTensor = GET_OP_API_FUNC (aclCreateTensor);
295
+ if (aclCreateTensor == nullptr ) {
296
+ LOG (WARNING) << " aclCreateTensor execute failed, return nullptr." ;
297
+ return nullptr ;
298
+ }
299
+ if (tp_tensor.first == nullptr || tp_tensor.second .size () < 1 ) {
300
+ LOG (WARNING) << " received nullptr or empty tensor, return nullptr." ;
301
+ return nullptr ;
302
+ }
303
+ auto at_tensor_dtype = phi::DataType::FLOAT16;
304
+ if (std::is_same<T, const phi::float16>::value ||
305
+ std::is_same<T, phi::float16>::value) {
306
+ at_tensor_dtype = phi::DataType::FLOAT16;
307
+ } else if (std::is_same<T, const int64_t >::value ||
308
+ std::is_same<T, int64_t >::value) {
309
+ at_tensor_dtype = phi::DataType::INT64;
310
+ } else if (std::is_same<T, const int32_t >::value ||
311
+ std::is_same<T, int32_t >::value) {
312
+ at_tensor_dtype = phi::DataType::INT32;
313
+ } else if (std::is_same<T, const int8_t >::value ||
314
+ std::is_same<T, int8_t >::value) {
315
+ at_tensor_dtype = phi::DataType::INT8;
316
+ }
317
+ auto acl_data_type = ConvertToNpuDtype (at_tensor_dtype);
318
+ const auto dimNum =
319
+ tp_tensor.second .size () == 0 ? 1 : tp_tensor.second .size ();
320
+ std::vector<int64_t > storageDims (dimNum - 1 );
321
+ int64_t tp_tensor_numel = 1 ;
322
+ for (auto &num : tp_tensor.second ) {
323
+ tp_tensor_numel *= num;
324
+ }
325
+ if (acl_data_type != ACL_STRING) {
326
+ storageDims.push_back (tp_tensor_numel * sizeof (at_tensor_dtype));
327
+ }
328
+ aclFormat format = ACL_FORMAT_ND;
329
+ switch (dimNum) {
330
+ case 4 :
331
+ format = ACL_FORMAT_NCHW;
332
+ break ;
333
+ case 5 :
334
+ format = ACL_FORMAT_NCDHW;
335
+ break ;
336
+ default :
337
+ format = ACL_FORMAT_ND;
338
+ }
339
+
340
+ std::vector<int64_t > origin_dims;
341
+ std::vector<int64_t > origin_strides;
342
+ if (tp_tensor.second .size () == 2 ) {
343
+ origin_dims = phi::vectorize ({tp_tensor.second [0 ], tp_tensor.second [1 ]});
344
+ origin_strides = phi::vectorize ({tp_tensor.second [1 ], 1 });
345
+ } else if (tp_tensor.second .size () == 1 ) {
346
+ origin_dims = phi::vectorize ({tp_tensor.second [0 ]});
347
+ origin_strides = phi::vectorize ({1 });
348
+ } else {
349
+ PADDLE_ENFORCE_GE (
350
+ tp_tensor.second .size (), 2 ,
351
+ phi::errors::InvalidArgument (" Only support 1-d and 2-d input" ));
352
+ return nullptr ;
353
+ }
354
+ auto acl_tensor = aclCreateTensor (origin_dims.data (),
355
+ origin_dims.size (),
356
+ acl_data_type,
357
+ origin_strides.data (),
358
+ 0 ,
359
+ format,
360
+ origin_dims.data (),
361
+ storageDims.size (),
362
+ const_cast <void *>(
363
+ static_cast <const void *>(
364
+ tp_tensor.first )));
365
+ return acl_tensor;
366
+ }
367
+
290
368
inline aclTensorList *ConvertType (
291
369
const std::vector<const phi::DenseTensor*> &phi_tensor_list) {
292
370
static const auto aclCreateTensorList = GET_OP_API_FUNC (aclCreateTensorList);
0 commit comments