Skip to content

Commit 52de798

Browse files
committed
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into v2_deadlinks
2 parents 333e26d + c7c62e0 commit 52de798

File tree

5 files changed

+181
-1
lines changed

5 files changed

+181
-1
lines changed

paddle/fluid/inference/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ endif()
2121

2222
if(WITH_TESTING)
2323
add_subdirectory(tests/book)
24+
# analysis test depends the models that generate by python/paddle/fluid/tests/book
25+
add_subdirectory(analysis)
2426
endif()
2527

2628
if (TENSORRT_FOUND)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
cc_library(dot SRCS dot.cc)
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/inference/analysis/dot.h"
16+
17+
namespace paddle {
18+
namespace inference {
19+
namespace analysis {
20+
size_t Dot::counter = 0;
21+
} // namespace analysis
22+
} // namespace inference
23+
} // namespace paddle

paddle/fluid/inference/analysis/dot.h

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
/*
16+
* This file implements some helper classes and methods for DOT programming
17+
* support. It will give a visualization of the graph and that helps to debug
18+
* the logics of each Pass.
19+
*/
20+
#pragma once
21+
22+
#include <glog/logging.h>
23+
#include <sstream>
24+
#include <unordered_map>
25+
#include <vector>
26+
27+
namespace paddle {
28+
namespace inference {
29+
namespace analysis {
30+
31+
/*
32+
* A Dot template that helps to build a DOT graph definition.
33+
*/
34+
class Dot {
35+
public:
36+
static size_t counter;
37+
38+
struct Attr {
39+
std::string key;
40+
std::string value;
41+
42+
Attr(const std::string& key, const std::string& value)
43+
: key(key), value(value) {}
44+
45+
std::string repr() const {
46+
std::stringstream ss;
47+
ss << key << "=" << '"' << value << '"';
48+
return ss.str();
49+
}
50+
};
51+
52+
struct Node {
53+
std::string name;
54+
std::vector<Attr> attrs;
55+
56+
Node(const std::string& name, const std::vector<Attr>& attrs)
57+
: name(name),
58+
attrs(attrs),
59+
id_("node_" + std::to_string(Dot::counter++)) {}
60+
61+
std::string id() const { return id_; }
62+
63+
std::string repr() const {
64+
std::stringstream ss;
65+
CHECK(!name.empty());
66+
ss << id_;
67+
for (size_t i = 0; i < attrs.size(); i++) {
68+
if (i == 0) {
69+
ss << "[label=" << '"' << name << '"' << " ";
70+
}
71+
ss << attrs[i].repr();
72+
ss << ((i < attrs.size() - 1) ? " " : "]");
73+
}
74+
return ss.str();
75+
}
76+
77+
private:
78+
std::string id_;
79+
};
80+
81+
struct Edge {
82+
std::string source;
83+
std::string target;
84+
std::vector<Attr> attrs;
85+
86+
Edge(const std::string& source, const std::string& target,
87+
const std::vector<Attr>& attrs)
88+
: source(source), target(target), attrs(attrs) {}
89+
90+
std::string repr() const {
91+
std::stringstream ss;
92+
CHECK(!source.empty());
93+
CHECK(!target.empty());
94+
ss << source << "->" << target;
95+
for (size_t i = 0; i < attrs.size(); i++) {
96+
if (i == 0) {
97+
ss << "[";
98+
}
99+
ss << attrs[i].repr();
100+
ss << ((i < attrs.size() - 1) ? " " : "]");
101+
}
102+
return ss.str();
103+
}
104+
};
105+
106+
Dot() = default;
107+
108+
explicit Dot(const std::vector<Attr>& attrs) : attrs_(attrs) {}
109+
110+
void AddNode(const std::string& name, const std::vector<Attr>& attrs) {
111+
CHECK(!nodes_.count(name)) << "duplicate Node '" << name << "'";
112+
nodes_.emplace(name, Node{name, attrs});
113+
}
114+
115+
void AddEdge(const std::string& source, const std::string& target,
116+
const std::vector<Attr>& attrs) {
117+
CHECK(!source.empty());
118+
CHECK(!target.empty());
119+
auto sid = nodes_.at(source).id();
120+
auto tid = nodes_.at(target).id();
121+
edges_.emplace_back(sid, tid, attrs);
122+
}
123+
124+
// Compile to DOT language codes.
125+
std::string Build() const {
126+
std::stringstream ss;
127+
const std::string indent = " ";
128+
ss << "digraph G {" << '\n';
129+
130+
// Add graph attrs
131+
for (const auto& attr : attrs_) {
132+
ss << indent << attr.repr() << '\n';
133+
}
134+
// add nodes
135+
for (auto& item : nodes_) {
136+
ss << indent << item.second.repr() << '\n';
137+
}
138+
// add edges
139+
for (auto& edge : edges_) {
140+
ss << indent << edge.repr() << '\n';
141+
}
142+
ss << "} // end G";
143+
return ss.str();
144+
}
145+
146+
private:
147+
std::unordered_map<std::string, Node> nodes_;
148+
std::vector<Edge> edges_;
149+
std::vector<Attr> attrs_;
150+
};
151+
152+
} // namespace analysis
153+
} // namespace inference
154+
} // namespace paddle

python/paddle/dataset/wmt16.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def __get_dict_size(src_dict_size, trg_dict_size, src_lang):
9696
src_dict_size = min(src_dict_size, (TOTAL_EN_WORDS if src_lang == "en" else
9797
TOTAL_DE_WORDS))
9898
trg_dict_size = min(trg_dict_size, (TOTAL_DE_WORDS if src_lang == "en" else
99-
TOTAL_ENG_WORDS))
99+
TOTAL_EN_WORDS))
100100
return src_dict_size, trg_dict_size
101101

102102

0 commit comments

Comments
 (0)