Skip to content

Commit 86fd3b3

Browse files
author
Tomasz Patejko
committed
MKLDNN residual connections fuse pass: counting statistics added to the pass
1 parent ee6f778 commit 86fd3b3

File tree

1 file changed

+44
-5
lines changed

1 file changed

+44
-5
lines changed

paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,20 +21,56 @@
2121
#include "paddle/fluid/framework/ir/graph.h"
2222
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
2323

24+
#include <boost/optional.hpp>
25+
2426
namespace paddle {
2527
namespace framework {
2628
namespace ir {
2729

30+
// poor replacement for C++17 std::optional and Boost.Optional
31+
struct InPlace {};
32+
InPlace in_place;
33+
34+
template <typename T>
35+
class Maybe {
36+
private:
37+
typename std::aligned_storage<sizeof(T), alignof(T)>::type data;
38+
bool is_initialized{false};
39+
40+
public:
41+
template <typename... Args>
42+
explicit Maybe(InPlace, Args&&... args) {
43+
new (&data) T(std::forward<Args>(args)...);
44+
is_initialized = true;
45+
}
46+
47+
Maybe() {}
48+
49+
operator bool() { return is_initialized; }
50+
51+
T& value() { return *reinterpret_cast<T*>(&data); }
52+
53+
~Maybe() { reinterpret_cast<T*>(&data)->~T(); }
54+
};
55+
56+
template <typename T, typename... Args>
57+
Maybe<T> MakeMaybe(Args&&... args) {
58+
return Maybe<T>(in_place, std::forward<Args>(args)...);
59+
}
60+
2861
using graph_ptr = std::unique_ptr<ir::Graph>;
62+
using GraphWithStats = std::pair<ir::Graph*, Maybe<int>>;
2963

3064
void CorrectGraphEdges(Graph* graph, Node* from, Node* to);
3165
bool IsReachable(ir::Graph* graph, Node* from, Node* to);
3266
std::pair<bool, Node*> HasBias(const Node& op, const std::string& bias_name);
3367

3468
class ResidualConnectionMKLDNNFusePass : public FusePassBase {
3569
private:
36-
graph_ptr FuseConvAsX(const std::string& name_scope_, graph_ptr graph) const;
37-
graph_ptr FuseConvAsY(const std::string& name_scope_, graph_ptr graph) const;
70+
GraphWithStats FuseConvAsX(const std::string& name_scope,
71+
const GraphWithStats& graph_with_stats) const;
72+
GraphWithStats FuseConvAsY(const std::string& name_scope,
73+
const GraphWithStats& graph_with_stats) const;
3874

3975
template <typename RetType>
4076
using GetNodeFunc =
@@ -48,12 +84,15 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase {
4884
const ElementwiseAddFunc& get_node_from_elementwise_add_op,
4985
const CanFuseFunc& can_fuse_func);
5086

87+
void operator()(const GraphPatternDetector::subgraph_t& subgraph,
88+
Graph* graph);
89+
int get_stats() const { return *fusion_stats; }
90+
91+
private:
92+
std::shared_ptr<int> fusion_stats;
5193
ConvFunc get_node_from_conv_op;
5294
ElementwiseAddFunc get_node_from_elementwise_add_op;
5395
CanFuseFunc can_fuse_func;
54-
55-
void operator()(const GraphPatternDetector::subgraph_t& subgraph,
56-
Graph* graph);
5796
};
5897

5998
public:

0 commit comments

Comments
 (0)