Skip to content

Commit 3f0b97d

Browse files
committed
update tensorrt subgraph_util test=release/1.4
(cherry picked from commit bddb2cd)
1 parent 8877054 commit 3f0b97d

File tree

1 file changed

+13
-12
lines changed

1 file changed

+13
-12
lines changed

paddle/fluid/inference/analysis/ir_passes/subgraph_util.cc

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,13 @@ void RenameAndGetOutputs(
7070
std::unordered_map<std::string /*name*/, int /*ITensor_quote_num*/>
7171
same_hierarchy_conv2d_num_map;
7272

73-
auto set_var_shape = [&](const std::string &arg_value) {
74-
auto arg_var_node = graph_var_map.find(arg_value);
73+
auto add_block_var = [&](const std::string &graph_arg,
74+
const std::string &block_arg) {
75+
auto arg_var_node = graph_var_map.find(graph_arg);
7576
PADDLE_ENFORCE(arg_var_node != graph_var_map.end());
76-
auto *var_t = block_desc->Var(arg_value);
77+
auto *var_t = block_desc->Var(block_arg);
7778
var_t->SetShape(arg_var_node->second->Var()->GetShape());
79+
var_t->SetDataType(arg_var_node->second->Var()->GetDataType());
7880
};
7981

8082
for (size_t index = 0; index < block_desc->OpSize(); ++index) {
@@ -99,15 +101,16 @@ void RenameAndGetOutputs(
99101
const std::string arg_value_with_id =
100102
arg_value + std::to_string(var2id[arg_value]);
101103

102-
bool is_var_in_graph = graph_var_map.count(arg_value);
103-
104104
if (input_names_with_id.count(arg_value_with_id)) {
105105
replaced_names.push_back(arg_value);
106+
if (graph_var_map.count(arg_value)) {
107+
add_block_var(arg_value, arg_value);
108+
}
106109
} else {
107110
replaced_names.push_back(arg_value_with_id);
108-
}
109-
if (is_var_in_graph) {
110-
set_var_shape(arg_value);
111+
if (graph_var_map.count(arg_value)) {
112+
add_block_var(arg_value, arg_value_with_id);
113+
}
111114
}
112115
}
113116
in_var->clear_arguments();
@@ -147,11 +150,9 @@ void RenameAndGetOutputs(
147150
const std::string arg_value_with_id =
148151
arg_value + std::to_string(var2id[arg_value]);
149152

150-
bool is_var_in_graph = graph_var_map.count(arg_value);
151-
if (is_var_in_graph) {
152-
set_var_shape(arg_value);
153+
if (graph_var_map.count(arg_value)) {
154+
add_block_var(arg_value, arg_value_with_id);
153155
}
154-
155156
if (output_names_with_id->count(arg_value_with_id)) {
156157
(*output_name_map)[arg_value] = arg_value_with_id;
157158
}

0 commit comments

Comments
 (0)