@@ -32,6 +32,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
32
32
platform::errors::InvalidArgument (
33
33
" Pointer to graph argument should not be NULL." ));
34
34
std::unordered_map<std::string, std::string> original_output_names;
35
+ std::unordered_set<std::string> inplaced_vars;
35
36
GraphPatternDetector gpd;
36
37
patterns::MKLDNNInPlace mkldnn_inplace{gpd.mutable_pattern (),
37
38
" mkldnn_inplace" };
@@ -94,6 +95,22 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
94
95
VLOG (3 ) << " DNNL in-place pass FAIL: in-place var cannot "
95
96
" be an input to multiple operators" ;
96
97
return ;
98
+ } else {
99
+ // We will prevent in-place when
100
+ // input is used in other part of graph, unless it was a result of
101
+ // inplacing
102
+ // Allow to next op out reuse inpuit var, as this is the same chaing
103
+ if (inplaced_vars.find (current_op_in->Name ()) == inplaced_vars.end ()) {
104
+ for (const Node* n : graph->Nodes ()) {
105
+ if ((n->id () != current_op_in->id ()) &&
106
+ (n->id () != next_op_out->id ()) &&
107
+ (n->Name () == current_op_in->Name ())) {
108
+ VLOG (3 ) << " DNNL in-place pass FAIL var used in diffrent part of "
109
+ " graph " ;
110
+ return ;
111
+ }
112
+ }
113
+ }
97
114
}
98
115
99
116
// If this op was alrady inplaced in previous pass placements
@@ -132,6 +149,8 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
132
149
auto out_name = in_to_outs.begin ()->second ;
133
150
current_op->Op ()->SetOutput (
134
151
out_name, std::vector<std::string>({current_op_out->Name ()}));
152
+ // Record var name
153
+ inplaced_vars.insert (current_op_out->Name ());
135
154
136
155
// If next op in a line is doing inplace
137
156
// then we need to update its output as well
0 commit comments