|
14 | 14 |
|
15 | 15 | #include "paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h"
|
16 | 16 | #include <functional>
|
17 |
| -#include <utility> |
| 17 | +#include <list> |
| 18 | +#include <map> |
| 19 | +#include <tuple> |
18 | 20 |
|
19 | 21 | #include "paddle/fluid/framework/ir/graph_traits.h"
|
20 | 22 |
|
21 | 23 | namespace paddle {
|
22 | 24 | namespace framework {
|
23 | 25 | namespace ir {
|
24 |
| -namespace { |
25 | 26 |
|
26 | 27 | // The function keeps the graph consistent by replacing
|
27 | 28 | // a node 'from' in the set of inputs nodes
|
@@ -51,104 +52,179 @@ void CorrectGraphEdges(Graph* graph, Node* from, Node* to) {
|
51 | 52 | }
|
52 | 53 | }
|
53 | 54 | }
|
54 |
| -} // namespace |
55 |
| -using graph_ptr = std::unique_ptr<ir::Graph>; |
56 | 55 |
|
57 |
| -graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const { |
58 |
| - FusePassBase::Init(name_scope_, graph.get()); |
| 56 | +bool IsReachable(ir::Graph* graph, Node* from, Node* to) { |
| 57 | + auto find_node = [](ir::Graph* graph, const Node* node) -> Node* { |
| 58 | + for (auto n : graph->Nodes()) { |
| 59 | + if (n == node) { |
| 60 | + return n; |
| 61 | + } |
| 62 | + } |
59 | 63 |
|
60 |
| - GraphPatternDetector gpd; |
61 |
| - auto pattern = gpd.mutable_pattern(); |
| 64 | + return nullptr; |
| 65 | + }; |
62 | 66 |
|
63 |
| - patterns::Conv conv_pattern{pattern, name_scope_}; |
64 |
| - auto conv_output = conv_pattern(); |
| 67 | + if (from == to) { |
| 68 | + return true; |
| 69 | + } |
65 | 70 |
|
66 |
| - patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope_}; |
67 |
| - elementwise_add_pattern(conv_output); |
| 71 | + std::map<Node*, bool> visited; |
68 | 72 |
|
69 |
| - conv_output->AsIntermediate(); |
| 73 | + for (auto& node : GraphTraits::DFS(*graph)) { |
| 74 | + visited[&node] = false; |
| 75 | + } |
70 | 76 |
|
71 |
| - auto conv_op_has_bias = [](const Node& conv_op) -> std::pair<bool, Node*> { |
72 |
| - auto bias_input_names = conv_op.Op()->Inputs(); |
73 |
| - auto bias_it = bias_input_names.find("Bias"); |
74 |
| - |
75 |
| - if (bias_it != std::end(bias_input_names)) { |
76 |
| - bool has_bias = !bias_it->second.empty(); |
77 |
| - |
78 |
| - if (has_bias) { |
79 |
| - auto conv_bias_names = bias_it->second; |
80 |
| - auto conv_bias_names_it = |
81 |
| - std::find_if(std::begin(conv_op.inputs), std::end(conv_op.inputs), |
82 |
| - [&conv_bias_names](Node* n) -> bool { |
83 |
| - return n->Name() == conv_bias_names[0]; |
84 |
| - }); |
85 |
| - return std::make_pair(has_bias, *conv_bias_names_it); |
86 |
| - } |
87 |
| - } |
| 77 | + visited[from] = true; |
88 | 78 |
|
89 |
| - return std::make_pair(false, nullptr); |
90 |
| - }; |
| 79 | + std::list<Node*> queue; |
| 80 | + queue.push_back(from); |
91 | 81 |
|
92 |
| - auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, |
93 |
| - Graph* g) { |
94 |
| - GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern); |
95 |
| - GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern); |
96 |
| - GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern); |
97 |
| - GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern); |
98 |
| - GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op, |
99 |
| - elementwise_add_pattern); |
100 |
| - GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_x, elementwise_add_x, |
101 |
| - elementwise_add_pattern); |
102 |
| - GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out, |
103 |
| - elementwise_add_pattern); |
| 82 | + while (!queue.empty()) { |
| 83 | + auto cur = find_node(graph, queue.front()); |
| 84 | + queue.pop_front(); |
104 | 85 |
|
105 |
| - if (FindFuseOption(*conv_op, *elementwise_add_op) != FUSE_MKLDNN) return; |
| 86 | + if (!cur) return false; |
106 | 87 |
|
107 |
| - OpDesc op_desc; |
108 |
| - op_desc.SetType("conv2d"); |
| 88 | + for (auto n : cur->outputs) { |
| 89 | + if (n == to) { |
| 90 | + return true; |
| 91 | + } |
109 | 92 |
|
110 |
| - op_desc.SetInput("Input", {conv_input->Name()}); |
111 |
| - op_desc.SetInput("Filter", {conv_filter->Name()}); |
112 |
| - op_desc.SetInput("ResidualData", {elementwise_add_x->Name()}); |
113 |
| - op_desc.SetOutput("Output", {conv_output->Name()}); |
| 93 | + if (!visited[n]) { |
| 94 | + visited[n] = true; |
| 95 | + queue.push_back(n); |
| 96 | + } |
| 97 | + } |
| 98 | + } |
| 99 | + return false; |
| 100 | +} |
114 | 101 |
|
115 |
| - bool has_bias; |
116 |
| - Node* conv_bias; |
| 102 | +std::pair<bool, Node*> ResidualConnectionMKLDNNFusePass::HasBias( |
| 103 | + const Node& op) const { |
| 104 | + auto bias_input_names = op.Op()->Inputs(); |
| 105 | + auto bias_it = bias_input_names.find("Bias"); |
117 | 106 |
|
118 |
| - std::tie(has_bias, conv_bias) = conv_op_has_bias(*conv_op); |
| 107 | + if (bias_it != std::end(bias_input_names)) { |
| 108 | + bool has_bias = !bias_it->second.empty(); |
119 | 109 |
|
120 | 110 | if (has_bias) {
|
121 |
| - op_desc.SetInput("Bias", {conv_bias->Name()}); |
| 111 | + auto bias_names = bias_it->second; |
| 112 | + auto bias_names_it = |
| 113 | + std::find_if(std::begin(op.inputs), std::end(op.inputs), |
| 114 | + [&bias_names](Node* n) -> bool { |
| 115 | + return n->Name() == bias_names[0]; |
| 116 | + }); |
| 117 | + return std::make_pair(has_bias, *bias_names_it); |
122 | 118 | }
|
| 119 | + } |
123 | 120 |
|
124 |
| - for (const auto& attr : conv_op->Op()->GetAttrMap()) { |
125 |
| - op_desc.SetAttr(attr.first, attr.second); |
126 |
| - } |
| 121 | + return std::make_pair(false, nullptr); |
| 122 | +} |
127 | 123 |
|
128 |
| - op_desc.SetAttr("fuse_residual_connection", true); |
| 124 | +graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsX( |
| 125 | + const std::string& name_scope_, graph_ptr graph) const { |
| 126 | + GraphPatternDetector gpd; |
| 127 | + auto pattern = gpd.mutable_pattern(); |
129 | 128 |
|
130 |
| - auto fused_conv_op = g->CreateOpNode(&op_desc); |
| 129 | + patterns::Conv conv_pattern{pattern, name_scope_}; |
| 130 | + auto conv_output = conv_pattern(); |
131 | 131 |
|
132 |
| - IR_NODE_LINK_TO(conv_input, fused_conv_op); |
133 |
| - IR_NODE_LINK_TO(conv_filter, fused_conv_op); |
134 |
| - IR_NODE_LINK_TO(elementwise_add_x, fused_conv_op); |
135 |
| - IR_NODE_LINK_TO(fused_conv_op, conv_output); |
| 132 | + patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope_}; |
| 133 | + elementwise_add_pattern( |
| 134 | + conv_output, |
| 135 | + pattern->NewNode(elementwise_add_pattern.elementwise_add_y_repr())); |
| 136 | + conv_output->AsIntermediate(); |
136 | 137 |
|
137 |
| - if (has_bias) { |
138 |
| - IR_NODE_LINK_TO(conv_bias, fused_conv_op); |
139 |
| - } |
| 138 | + auto get_node_from_conv = [](const patterns::Conv& conv_pattern, |
| 139 | + const GraphPatternDetector::subgraph_t& subgraph) |
| 140 | + -> std::tuple<Node*, Node*, Node*, Node*> { |
| 141 | + GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern); |
| 142 | + GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern); |
| 143 | + GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern); |
| 144 | + GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern); |
| 145 | + |
| 146 | + return std::make_tuple(conv_op, conv_input, conv_filter, conv_output); |
| 147 | + }; |
| 148 | + |
| 149 | + auto get_node_from_elementwise_add = []( |
| 150 | + const patterns::ElementwiseAdd& elementwise_add_pattern, |
| 151 | + const GraphPatternDetector::subgraph_t& subgraph) |
| 152 | + -> std::tuple<Node*, Node*, Node*> { |
| 153 | + GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op, |
| 154 | + elementwise_add_pattern); |
| 155 | + GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_y, elementwise_add_y, |
| 156 | + elementwise_add_pattern); |
| 157 | + GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out, |
| 158 | + elementwise_add_pattern); |
| 159 | + |
| 160 | + return std::make_tuple(elementwise_add_op, elementwise_add_y, |
| 161 | + elementwise_add_out); |
| 162 | + }; |
| 163 | + |
| 164 | + auto handler = |
| 165 | + GenerateFuseHandler(conv_pattern, elementwise_add_pattern, |
| 166 | + get_node_from_conv, get_node_from_elementwise_add); |
| 167 | + gpd(graph.get(), handler); |
140 | 168 |
|
141 |
| - CorrectGraphEdges(g, elementwise_add_out, conv_output); |
142 |
| - GraphSafeRemoveNodes(g, {elementwise_add_out, conv_op, elementwise_add_op}); |
143 |
| - }; |
| 169 | + return graph; |
| 170 | +} |
| 171 | + |
| 172 | +graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsY( |
| 173 | + const std::string& name_scope_, graph_ptr graph) const { |
| 174 | + GraphPatternDetector gpd; |
| 175 | + auto pattern = gpd.mutable_pattern(); |
| 176 | + |
| 177 | + patterns::Conv conv_pattern{pattern, name_scope_}; |
| 178 | + auto conv_output = conv_pattern(); |
| 179 | + |
| 180 | + patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope_}; |
| 181 | + elementwise_add_pattern( |
| 182 | + pattern->NewNode(elementwise_add_pattern.elementwise_add_x_repr()), |
| 183 | + conv_output); |
| 184 | + conv_output->AsIntermediate(); |
144 | 185 |
|
| 186 | + auto get_node_from_conv = [](const patterns::Conv& conv_pattern, |
| 187 | + const GraphPatternDetector::subgraph_t& subgraph) |
| 188 | + -> std::tuple<Node*, Node*, Node*, Node*> { |
| 189 | + GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern); |
| 190 | + GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern); |
| 191 | + GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern); |
| 192 | + GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern); |
| 193 | + |
| 194 | + return std::make_tuple(conv_op, conv_input, conv_filter, conv_output); |
| 195 | + }; |
| 196 | + |
| 197 | + auto get_node_from_elementwise_add = []( |
| 198 | + const patterns::ElementwiseAdd& elementwise_add_pattern, |
| 199 | + const GraphPatternDetector::subgraph_t& subgraph) |
| 200 | + -> std::tuple<Node*, Node*, Node*> { |
| 201 | + GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op, |
| 202 | + elementwise_add_pattern); |
| 203 | + GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_x, elementwise_add_x, |
| 204 | + elementwise_add_pattern); |
| 205 | + GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out, |
| 206 | + elementwise_add_pattern); |
| 207 | + |
| 208 | + return std::make_tuple(elementwise_add_op, elementwise_add_x, |
| 209 | + elementwise_add_out); |
| 210 | + }; |
| 211 | + |
| 212 | + auto handler = |
| 213 | + GenerateFuseHandler(conv_pattern, elementwise_add_pattern, |
| 214 | + get_node_from_conv, get_node_from_elementwise_add); |
145 | 215 | gpd(graph.get(), handler);
|
146 | 216 |
|
147 | 217 | return graph;
|
148 | 218 | }
|
| 219 | + |
| 220 | +graph_ptr ResidualConnectionMKLDNNFusePass::ApplyImpl(graph_ptr graph) const { |
| 221 | + FusePassBase::Init(name_scope_, graph.get()); |
| 222 | + |
| 223 | + return FuseConvAsY(name_scope_, FuseConvAsX(name_scope_, std::move(graph))); |
| 224 | +} |
149 | 225 | } // namespace ir
|
150 | 226 | } // namespace framework
|
151 | 227 | } // namespace paddle
|
152 | 228 |
|
153 | 229 | REGISTER_PASS(conv_elementwise_add_mkldnn_fuse_pass,
|
154 |
| - paddle::framework::ir::ConvElementwiseAddMKLDNNFusePass); |
| 230 | + paddle::framework::ir::ResidualConnectionMKLDNNFusePass); |
0 commit comments