Skip to content

Commit f143c9a

Browse files
authored
port VisualDL's graphviz theme to IR (#13246)
1 parent 6e03f79 commit f143c9a

File tree

1 file changed

+42
-17
lines changed

1 file changed

+42
-17
lines changed

paddle/fluid/framework/ir/graph_viz_pass.cc

Lines changed: 42 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -50,20 +50,37 @@ std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl(
5050

5151
Dot dot;
5252

53-
std::vector<Dot::Attr> op_attrs({Dot::Attr("style", "filled"),
54-
Dot::Attr("shape", "box"),
55-
Dot::Attr("fillcolor", "red")});
56-
std::vector<Dot::Attr> var_attrs({Dot::Attr("style", "filled,rounded"),
57-
// Dot::Attr("shape", "diamond"),
58-
Dot::Attr("fillcolor", "yellow")});
59-
60-
std::vector<Dot::Attr> marked_op_attrs({Dot::Attr("style", "filled"),
61-
Dot::Attr("shape", "box"),
62-
Dot::Attr("fillcolor", "lightgray")});
63-
std::vector<Dot::Attr> marked_var_attrs(
64-
{Dot::Attr("style", "filled,rounded"),
65-
// Dot::Attr("shape", "diamond"),
66-
Dot::Attr("fillcolor", "lightgray")});
53+
const std::vector<Dot::Attr> op_attrs({
54+
Dot::Attr("style", "rounded,filled,bold"), //
55+
Dot::Attr("shape", "box"), //
56+
Dot::Attr("color", "#303A3A"), //
57+
Dot::Attr("fontcolor", "#ffffff"), //
58+
Dot::Attr("width", "1.3"), //
59+
Dot::Attr("height", "0.84"), //
60+
Dot::Attr("fontname", "Arial"), //
61+
});
62+
const std::vector<Dot::Attr> arg_attrs({
63+
Dot::Attr("shape", "box"), //
64+
Dot::Attr("style", "rounded,filled,bold"), //
65+
Dot::Attr("fontname", "Arial"), //
66+
Dot::Attr("fillcolor", "#999999"), //
67+
Dot::Attr("color", "#dddddd"), //
68+
});
69+
70+
const std::vector<Dot::Attr> param_attrs({
71+
Dot::Attr("shape", "box"), //
72+
Dot::Attr("style", "rounded,filled,bold"), //
73+
Dot::Attr("fontname", "Arial"), //
74+
Dot::Attr("color", "#148b97"), //
75+
Dot::Attr("fontcolor", "#ffffff"), //
76+
});
77+
78+
const std::vector<Dot::Attr> marked_op_attrs(
79+
{Dot::Attr("style", "rounded,filled,bold"), Dot::Attr("shape", "box"),
80+
Dot::Attr("fillcolor", "yellow")});
81+
const std::vector<Dot::Attr> marked_var_attrs(
82+
{Dot::Attr("style", "filled,rounded"), Dot::Attr("shape", "box"),
83+
Dot::Attr("fillcolor", "yellow")});
6784

6885
auto marked_nodes = ConsumeMarkedNodes(graph.get());
6986
// Create nodes
@@ -74,9 +91,17 @@ std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl(
7491
marked_nodes.count(n) ? marked_op_attrs : op_attrs;
7592
dot.AddNode(node_id, attr, node_id);
7693
} else if (n->IsVar()) {
77-
decltype(op_attrs) attr =
78-
marked_nodes.count(n) ? marked_var_attrs : var_attrs;
79-
dot.AddNode(node_id, attr, node_id);
94+
decltype(op_attrs)* attr;
95+
if (marked_nodes.count(n)) {
96+
attr = &marked_var_attrs;
97+
} else if (const_cast<Node*>(n)->Var() &&
98+
const_cast<Node*>(n)->Var()->Persistable()) {
99+
attr = &param_attrs;
100+
} else {
101+
attr = &arg_attrs;
102+
}
103+
104+
dot.AddNode(node_id, *attr, node_id);
80105
}
81106
node2dot[n] = node_id;
82107
}

0 commit comments

Comments
 (0)