@@ -2500,34 +2500,41 @@ std::map<std::string, paddle_infer::DataType>
2500
2500
AnalysisPredictor::GetInputTypes () {
2501
2501
std::map<std::string, paddle_infer::DataType> input_type;
2502
2502
std::vector<std::string> names = GetInputNames ();
2503
- for (const auto &name : names) {
2504
- auto *var = inference_program_->Block (0 ).FindVar (name);
2505
- PADDLE_ENFORCE_NOT_NULL (
2506
- var,
2507
- common::errors::PreconditionNotMet (
2508
- " Input %s does not exist inference_program_." , name));
2509
- auto dtype = var->GetDataType ();
2510
- if (dtype == paddle::framework::proto::VarType::FP32) {
2511
- input_type[name] = paddle_infer::DataType::FLOAT32;
2512
- } else if (dtype == paddle::framework::proto::VarType::FP16) {
2513
- input_type[name] = paddle_infer::DataType::FLOAT16;
2514
- } else if (dtype == paddle::framework::proto::VarType::BF16) {
2515
- input_type[name] = paddle_infer::DataType::BFLOAT16;
2516
- } else if (dtype == paddle::framework::proto::VarType::INT64) {
2517
- input_type[name] = paddle_infer::DataType::INT64;
2518
- } else if (dtype == paddle::framework::proto::VarType::INT32) {
2519
- input_type[name] = paddle_infer::DataType::INT32;
2520
- } else if (dtype == paddle::framework::proto::VarType::UINT8) {
2521
- input_type[name] = paddle_infer::DataType::UINT8;
2522
- } else if (dtype == paddle::framework::proto::VarType::INT8) {
2523
- input_type[name] = paddle_infer::DataType::INT8;
2524
- } else if (dtype == paddle::framework::proto::VarType::FP64) {
2525
- input_type[name] = paddle_infer::DataType::FLOAT64;
2526
- } else if (dtype == paddle::framework::proto::VarType::BOOL) {
2527
- input_type[name] = paddle_infer::DataType::BOOL;
2528
- } else {
2529
- PADDLE_THROW (common::errors::Unimplemented (
2530
- " Unsupported data type `%s` when get input dtype " , dtype));
2503
+ if (load_pir_model_) {
2504
+ for (const auto &name : names) {
2505
+ auto tensor = GetInputTensor (name);
2506
+ input_type[name] = tensor->type ();
2507
+ }
2508
+ } else {
2509
+ for (const auto &name : names) {
2510
+ auto *var = inference_program_->Block (0 ).FindVar (name);
2511
+ PADDLE_ENFORCE_NOT_NULL (
2512
+ var,
2513
+ common::errors::PreconditionNotMet (
2514
+ " Input %s does not exist inference_program_." , name));
2515
+ auto dtype = var->GetDataType ();
2516
+ if (dtype == paddle::framework::proto::VarType::FP32) {
2517
+ input_type[name] = paddle_infer::DataType::FLOAT32;
2518
+ } else if (dtype == paddle::framework::proto::VarType::FP16) {
2519
+ input_type[name] = paddle_infer::DataType::FLOAT16;
2520
+ } else if (dtype == paddle::framework::proto::VarType::BF16) {
2521
+ input_type[name] = paddle_infer::DataType::BFLOAT16;
2522
+ } else if (dtype == paddle::framework::proto::VarType::INT64) {
2523
+ input_type[name] = paddle_infer::DataType::INT64;
2524
+ } else if (dtype == paddle::framework::proto::VarType::INT32) {
2525
+ input_type[name] = paddle_infer::DataType::INT32;
2526
+ } else if (dtype == paddle::framework::proto::VarType::UINT8) {
2527
+ input_type[name] = paddle_infer::DataType::UINT8;
2528
+ } else if (dtype == paddle::framework::proto::VarType::INT8) {
2529
+ input_type[name] = paddle_infer::DataType::INT8;
2530
+ } else if (dtype == paddle::framework::proto::VarType::FP64) {
2531
+ input_type[name] = paddle_infer::DataType::FLOAT64;
2532
+ } else if (dtype == paddle::framework::proto::VarType::BOOL) {
2533
+ input_type[name] = paddle_infer::DataType::BOOL;
2534
+ } else {
2535
+ PADDLE_THROW (common::errors::Unimplemented (
2536
+ " Unsupported data type `%s` when get input dtype " , dtype));
2537
+ }
2531
2538
}
2532
2539
}
2533
2540
return input_type;
@@ -2562,30 +2569,37 @@ std::map<std::string, paddle_infer::DataType>
2562
2569
AnalysisPredictor::GetOutputTypes () {
2563
2570
std::map<std::string, paddle_infer::DataType> output_type;
2564
2571
std::vector<std::string> names = GetOutputNames ();
2565
- for (const auto &name : names) {
2566
- auto *var = inference_program_->Block (0 ).FindVar (name);
2567
- PADDLE_ENFORCE_NOT_NULL (
2568
- var,
2569
- common::errors::PreconditionNotMet (
2570
- " Output %s does not exist inference_program_." , name));
2571
- auto dtype = var->GetDataType ();
2572
- if (dtype == paddle::framework::proto::VarType::FP32) {
2573
- output_type[name] = paddle_infer::DataType::FLOAT32;
2574
- } else if (dtype == paddle::framework::proto::VarType::FP16) {
2575
- output_type[name] = paddle_infer::DataType::FLOAT16;
2576
- } else if (dtype == paddle::framework::proto::VarType::BF16) {
2577
- output_type[name] = paddle_infer::DataType::BFLOAT16;
2578
- } else if (dtype == paddle::framework::proto::VarType::INT64) {
2579
- output_type[name] = paddle_infer::DataType::INT64;
2580
- } else if (dtype == paddle::framework::proto::VarType::INT32) {
2581
- output_type[name] = paddle_infer::DataType::INT32;
2582
- } else if (dtype == paddle::framework::proto::VarType::UINT8) {
2583
- output_type[name] = paddle_infer::DataType::UINT8;
2584
- } else if (dtype == paddle::framework::proto::VarType::INT8) {
2585
- output_type[name] = paddle_infer::DataType::INT8;
2586
- } else {
2587
- PADDLE_THROW (common::errors::Unimplemented (
2588
- " Unsupported data type `%s` when get output dtype " , dtype));
2572
+ if (load_pir_model_) {
2573
+ for (const auto &name : names) {
2574
+ auto tensor = GetOutputTensor (name);
2575
+ output_type[name] = tensor->type ();
2576
+ }
2577
+ } else {
2578
+ for (const auto &name : names) {
2579
+ auto *var = inference_program_->Block (0 ).FindVar (name);
2580
+ PADDLE_ENFORCE_NOT_NULL (
2581
+ var,
2582
+ common::errors::PreconditionNotMet (
2583
+ " Output %s does not exist inference_program_." , name));
2584
+ auto dtype = var->GetDataType ();
2585
+ if (dtype == paddle::framework::proto::VarType::FP32) {
2586
+ output_type[name] = paddle_infer::DataType::FLOAT32;
2587
+ } else if (dtype == paddle::framework::proto::VarType::FP16) {
2588
+ output_type[name] = paddle_infer::DataType::FLOAT16;
2589
+ } else if (dtype == paddle::framework::proto::VarType::BF16) {
2590
+ output_type[name] = paddle_infer::DataType::BFLOAT16;
2591
+ } else if (dtype == paddle::framework::proto::VarType::INT64) {
2592
+ output_type[name] = paddle_infer::DataType::INT64;
2593
+ } else if (dtype == paddle::framework::proto::VarType::INT32) {
2594
+ output_type[name] = paddle_infer::DataType::INT32;
2595
+ } else if (dtype == paddle::framework::proto::VarType::UINT8) {
2596
+ output_type[name] = paddle_infer::DataType::UINT8;
2597
+ } else if (dtype == paddle::framework::proto::VarType::INT8) {
2598
+ output_type[name] = paddle_infer::DataType::INT8;
2599
+ } else {
2600
+ PADDLE_THROW (common::errors::Unimplemented (
2601
+ " Unsupported data type `%s` when get output dtype " , dtype));
2602
+ }
2589
2603
}
2590
2604
}
2591
2605
return output_type;
0 commit comments