21
21
#include " paddle/fluid/framework/ir/graph.h"
22
22
#include " paddle/fluid/framework/ir/graph_pattern_detector.h"
23
23
24
+ #include < boost/optional.hpp>
25
+
24
26
namespace paddle {
25
27
namespace framework {
26
28
namespace ir {
27
29
30
+ // poor replacement for C++17 std::optional and Boost.Optional
31
+ struct InPlace {};
32
+ InPlace in_place;
33
+
34
+ template <typename T>
35
+ class Maybe {
36
+ private:
37
+ typename std::aligned_storage<sizeof (T), alignof (T)>::type data;
38
+ bool is_initialized{false };
39
+
40
+ public:
41
+ template <typename ... Args>
42
+ explicit Maybe (InPlace, Args&&... args) {
43
+ new (&data) T (std::forward<Args>(args)...);
44
+ is_initialized = true ;
45
+ }
46
+
47
+ Maybe () {}
48
+
49
+ operator bool () { return is_initialized; }
50
+
51
+ T& value () { return *reinterpret_cast <T*>(&data); }
52
+
53
+ ~Maybe () { reinterpret_cast <T*>(&data)->~T (); }
54
+ };
55
+
56
+ template <typename T, typename ... Args>
57
+ Maybe<T> MakeMaybe (Args&&... args) {
58
+ return Maybe<T>(in_place, std::forward<Args>(args)...);
59
+ }
60
+
28
61
using graph_ptr = std::unique_ptr<ir::Graph>;
62
+ using GraphWithStats = std::pair<ir::Graph*, Maybe<int >>;
29
63
30
64
void CorrectGraphEdges (Graph* graph, Node* from, Node* to);
31
65
bool IsReachable (ir::Graph* graph, Node* from, Node* to);
32
66
std::pair<bool , Node*> HasBias (const Node& op, const std::string& bias_name);
33
67
34
68
class ResidualConnectionMKLDNNFusePass : public FusePassBase {
35
69
private:
36
- graph_ptr FuseConvAsX (const std::string& name_scope_, graph_ptr graph) const ;
37
- graph_ptr FuseConvAsY (const std::string& name_scope_, graph_ptr graph) const ;
70
+ GraphWithStats FuseConvAsX (const std::string& name_scope,
71
+ const GraphWithStats& graph_with_stats) const ;
72
+ GraphWithStats FuseConvAsY (const std::string& name_scope,
73
+ const GraphWithStats& graph_with_stats) const ;
38
74
39
75
template <typename RetType>
40
76
using GetNodeFunc =
@@ -48,12 +84,15 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase {
48
84
const ElementwiseAddFunc& get_node_from_elementwise_add_op,
49
85
const CanFuseFunc& can_fuse_func);
50
86
87
+ void operator ()(const GraphPatternDetector::subgraph_t & subgraph,
88
+ Graph* graph);
89
+ int get_stats () const { return *fusion_stats; }
90
+
91
+ private:
92
+ std::shared_ptr<int > fusion_stats;
51
93
ConvFunc get_node_from_conv_op;
52
94
ElementwiseAddFunc get_node_from_elementwise_add_op;
53
95
CanFuseFunc can_fuse_func;
54
-
55
- void operator ()(const GraphPatternDetector::subgraph_t & subgraph,
56
- Graph* graph);
57
96
};
58
97
59
98
public:
0 commit comments