@@ -70,11 +70,13 @@ void RenameAndGetOutputs(
70
70
std::unordered_map<std::string /* name*/ , int /* ITensor_quote_num*/ >
71
71
same_hierarchy_conv2d_num_map;
72
72
73
- auto set_var_shape = [&](const std::string &arg_value) {
74
- auto arg_var_node = graph_var_map.find (arg_value);
73
+ auto add_block_var = [&](const std::string &graph_arg,
74
+ const std::string &block_arg) {
75
+ auto arg_var_node = graph_var_map.find (graph_arg);
75
76
PADDLE_ENFORCE (arg_var_node != graph_var_map.end ());
76
- auto *var_t = block_desc->Var (arg_value );
77
+ auto *var_t = block_desc->Var (block_arg );
77
78
var_t ->SetShape (arg_var_node->second ->Var ()->GetShape ());
79
+ var_t ->SetDataType (arg_var_node->second ->Var ()->GetDataType ());
78
80
};
79
81
80
82
for (size_t index = 0 ; index < block_desc->OpSize (); ++index) {
@@ -99,15 +101,16 @@ void RenameAndGetOutputs(
99
101
const std::string arg_value_with_id =
100
102
arg_value + std::to_string (var2id[arg_value]);
101
103
102
- bool is_var_in_graph = graph_var_map.count (arg_value);
103
-
104
104
if (input_names_with_id.count (arg_value_with_id)) {
105
105
replaced_names.push_back (arg_value);
106
+ if (graph_var_map.count (arg_value)) {
107
+ add_block_var (arg_value, arg_value);
108
+ }
106
109
} else {
107
110
replaced_names.push_back (arg_value_with_id);
108
- }
109
- if (is_var_in_graph) {
110
- set_var_shape (arg_value);
111
+ if (graph_var_map. count (arg_value)) {
112
+ add_block_var (arg_value, arg_value_with_id);
113
+ }
111
114
}
112
115
}
113
116
in_var->clear_arguments ();
@@ -147,11 +150,9 @@ void RenameAndGetOutputs(
147
150
const std::string arg_value_with_id =
148
151
arg_value + std::to_string (var2id[arg_value]);
149
152
150
- bool is_var_in_graph = graph_var_map.count (arg_value);
151
- if (is_var_in_graph) {
152
- set_var_shape (arg_value);
153
+ if (graph_var_map.count (arg_value)) {
154
+ add_block_var (arg_value, arg_value_with_id);
153
155
}
154
-
155
156
if (output_names_with_id->count (arg_value_with_id)) {
156
157
(*output_name_map)[arg_value] = arg_value_with_id;
157
158
}
0 commit comments