Skip to content

Commit 772746c

Browse files
[oneDNN] Fix to inplace pass (#24442) (#25182)
1 parent ddc7f39 commit 772746c

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass.cc

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
3232
platform::errors::InvalidArgument(
3333
"Pointer to graph argument should not be NULL."));
3434
std::unordered_map<std::string, std::string> original_output_names;
35+
std::unordered_set<std::string> inplaced_vars;
3536
GraphPatternDetector gpd;
3637
patterns::MKLDNNInPlace mkldnn_inplace{gpd.mutable_pattern(),
3738
"mkldnn_inplace"};
@@ -94,6 +95,22 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
9495
VLOG(3) << "DNNL in-place pass FAIL: in-place var cannot "
9596
"be an input to multiple operators";
9697
return;
98+
} else {
99+
// We will prevent in-place when
100+
// input is used in other part of graph, unless it was a result of
101+
// inplacing
102+
// Allow to next op out reuse inpuit var, as this is the same chaing
103+
if (inplaced_vars.find(current_op_in->Name()) == inplaced_vars.end()) {
104+
for (const Node* n : graph->Nodes()) {
105+
if ((n->id() != current_op_in->id()) &&
106+
(n->id() != next_op_out->id()) &&
107+
(n->Name() == current_op_in->Name())) {
108+
VLOG(3) << "DNNL in-place pass FAIL var used in diffrent part of "
109+
"graph ";
110+
return;
111+
}
112+
}
113+
}
97114
}
98115

99116
// If this op was alrady inplaced in previous pass placements
@@ -132,6 +149,8 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
132149
auto out_name = in_to_outs.begin()->second;
133150
current_op->Op()->SetOutput(
134151
out_name, std::vector<std::string>({current_op_out->Name()}));
152+
// Record var name
153+
inplaced_vars.insert(current_op_out->Name());
135154

136155
// If next op in a line is doing inplace
137156
// then we need to update its output as well

0 commit comments

Comments
 (0)