Skip to content

Commit e36cd94

Browse files
authored
fix GetInputTypes() & GetOutputTypes() when using .json model (#74358)
* fix GetInputTypes() & GetOutputTypes() when using .json model * fix GetInputTypes() & GetOutputTypes() when using .json model * Update analysis_predictor.cc * Update analysis_predictor.cc * fix GetInputTypes() & GetOutputTypes() when using .json model * fix GetInputTypes() & GetOutputTypes() when using .json model
1 parent 7404342 commit e36cd94

File tree

1 file changed

+66
-52
lines changed

1 file changed

+66
-52
lines changed

paddle/fluid/inference/api/analysis_predictor.cc

Lines changed: 66 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -2500,34 +2500,41 @@ std::map<std::string, paddle_infer::DataType>
25002500
AnalysisPredictor::GetInputTypes() {
25012501
std::map<std::string, paddle_infer::DataType> input_type;
25022502
std::vector<std::string> names = GetInputNames();
2503-
for (const auto &name : names) {
2504-
auto *var = inference_program_->Block(0).FindVar(name);
2505-
PADDLE_ENFORCE_NOT_NULL(
2506-
var,
2507-
common::errors::PreconditionNotMet(
2508-
"Input %s does not exist inference_program_.", name));
2509-
auto dtype = var->GetDataType();
2510-
if (dtype == paddle::framework::proto::VarType::FP32) {
2511-
input_type[name] = paddle_infer::DataType::FLOAT32;
2512-
} else if (dtype == paddle::framework::proto::VarType::FP16) {
2513-
input_type[name] = paddle_infer::DataType::FLOAT16;
2514-
} else if (dtype == paddle::framework::proto::VarType::BF16) {
2515-
input_type[name] = paddle_infer::DataType::BFLOAT16;
2516-
} else if (dtype == paddle::framework::proto::VarType::INT64) {
2517-
input_type[name] = paddle_infer::DataType::INT64;
2518-
} else if (dtype == paddle::framework::proto::VarType::INT32) {
2519-
input_type[name] = paddle_infer::DataType::INT32;
2520-
} else if (dtype == paddle::framework::proto::VarType::UINT8) {
2521-
input_type[name] = paddle_infer::DataType::UINT8;
2522-
} else if (dtype == paddle::framework::proto::VarType::INT8) {
2523-
input_type[name] = paddle_infer::DataType::INT8;
2524-
} else if (dtype == paddle::framework::proto::VarType::FP64) {
2525-
input_type[name] = paddle_infer::DataType::FLOAT64;
2526-
} else if (dtype == paddle::framework::proto::VarType::BOOL) {
2527-
input_type[name] = paddle_infer::DataType::BOOL;
2528-
} else {
2529-
PADDLE_THROW(common::errors::Unimplemented(
2530-
"Unsupported data type `%s` when get input dtype ", dtype));
2503+
if (load_pir_model_) {
2504+
for (const auto &name : names) {
2505+
auto tensor = GetInputTensor(name);
2506+
input_type[name] = tensor->type();
2507+
}
2508+
} else {
2509+
for (const auto &name : names) {
2510+
auto *var = inference_program_->Block(0).FindVar(name);
2511+
PADDLE_ENFORCE_NOT_NULL(
2512+
var,
2513+
common::errors::PreconditionNotMet(
2514+
"Input %s does not exist inference_program_.", name));
2515+
auto dtype = var->GetDataType();
2516+
if (dtype == paddle::framework::proto::VarType::FP32) {
2517+
input_type[name] = paddle_infer::DataType::FLOAT32;
2518+
} else if (dtype == paddle::framework::proto::VarType::FP16) {
2519+
input_type[name] = paddle_infer::DataType::FLOAT16;
2520+
} else if (dtype == paddle::framework::proto::VarType::BF16) {
2521+
input_type[name] = paddle_infer::DataType::BFLOAT16;
2522+
} else if (dtype == paddle::framework::proto::VarType::INT64) {
2523+
input_type[name] = paddle_infer::DataType::INT64;
2524+
} else if (dtype == paddle::framework::proto::VarType::INT32) {
2525+
input_type[name] = paddle_infer::DataType::INT32;
2526+
} else if (dtype == paddle::framework::proto::VarType::UINT8) {
2527+
input_type[name] = paddle_infer::DataType::UINT8;
2528+
} else if (dtype == paddle::framework::proto::VarType::INT8) {
2529+
input_type[name] = paddle_infer::DataType::INT8;
2530+
} else if (dtype == paddle::framework::proto::VarType::FP64) {
2531+
input_type[name] = paddle_infer::DataType::FLOAT64;
2532+
} else if (dtype == paddle::framework::proto::VarType::BOOL) {
2533+
input_type[name] = paddle_infer::DataType::BOOL;
2534+
} else {
2535+
PADDLE_THROW(common::errors::Unimplemented(
2536+
"Unsupported data type `%s` when get input dtype ", dtype));
2537+
}
25312538
}
25322539
}
25332540
return input_type;
@@ -2562,30 +2569,37 @@ std::map<std::string, paddle_infer::DataType>
25622569
AnalysisPredictor::GetOutputTypes() {
25632570
std::map<std::string, paddle_infer::DataType> output_type;
25642571
std::vector<std::string> names = GetOutputNames();
2565-
for (const auto &name : names) {
2566-
auto *var = inference_program_->Block(0).FindVar(name);
2567-
PADDLE_ENFORCE_NOT_NULL(
2568-
var,
2569-
common::errors::PreconditionNotMet(
2570-
"Output %s does not exist inference_program_.", name));
2571-
auto dtype = var->GetDataType();
2572-
if (dtype == paddle::framework::proto::VarType::FP32) {
2573-
output_type[name] = paddle_infer::DataType::FLOAT32;
2574-
} else if (dtype == paddle::framework::proto::VarType::FP16) {
2575-
output_type[name] = paddle_infer::DataType::FLOAT16;
2576-
} else if (dtype == paddle::framework::proto::VarType::BF16) {
2577-
output_type[name] = paddle_infer::DataType::BFLOAT16;
2578-
} else if (dtype == paddle::framework::proto::VarType::INT64) {
2579-
output_type[name] = paddle_infer::DataType::INT64;
2580-
} else if (dtype == paddle::framework::proto::VarType::INT32) {
2581-
output_type[name] = paddle_infer::DataType::INT32;
2582-
} else if (dtype == paddle::framework::proto::VarType::UINT8) {
2583-
output_type[name] = paddle_infer::DataType::UINT8;
2584-
} else if (dtype == paddle::framework::proto::VarType::INT8) {
2585-
output_type[name] = paddle_infer::DataType::INT8;
2586-
} else {
2587-
PADDLE_THROW(common::errors::Unimplemented(
2588-
"Unsupported data type `%s` when get output dtype ", dtype));
2572+
if (load_pir_model_) {
2573+
for (const auto &name : names) {
2574+
auto tensor = GetOutputTensor(name);
2575+
output_type[name] = tensor->type();
2576+
}
2577+
} else {
2578+
for (const auto &name : names) {
2579+
auto *var = inference_program_->Block(0).FindVar(name);
2580+
PADDLE_ENFORCE_NOT_NULL(
2581+
var,
2582+
common::errors::PreconditionNotMet(
2583+
"Output %s does not exist inference_program_.", name));
2584+
auto dtype = var->GetDataType();
2585+
if (dtype == paddle::framework::proto::VarType::FP32) {
2586+
output_type[name] = paddle_infer::DataType::FLOAT32;
2587+
} else if (dtype == paddle::framework::proto::VarType::FP16) {
2588+
output_type[name] = paddle_infer::DataType::FLOAT16;
2589+
} else if (dtype == paddle::framework::proto::VarType::BF16) {
2590+
output_type[name] = paddle_infer::DataType::BFLOAT16;
2591+
} else if (dtype == paddle::framework::proto::VarType::INT64) {
2592+
output_type[name] = paddle_infer::DataType::INT64;
2593+
} else if (dtype == paddle::framework::proto::VarType::INT32) {
2594+
output_type[name] = paddle_infer::DataType::INT32;
2595+
} else if (dtype == paddle::framework::proto::VarType::UINT8) {
2596+
output_type[name] = paddle_infer::DataType::UINT8;
2597+
} else if (dtype == paddle::framework::proto::VarType::INT8) {
2598+
output_type[name] = paddle_infer::DataType::INT8;
2599+
} else {
2600+
PADDLE_THROW(common::errors::Unimplemented(
2601+
"Unsupported data type `%s` when get output dtype ", dtype));
2602+
}
25892603
}
25902604
}
25912605
return output_type;

0 commit comments

Comments
 (0)