Skip to content

Commit cdb5f29

Browse files
authored
Add a C++ program that prints operator document in JSON format (#4981)
* Add print_operators_doc.cc * Update Escape * Correct a bug * Remove OpInfoMap::Iterate * Update the print_operators_doc.cc * Escape tab * Use auto& * Use auto& * Remove trailing , * clang-format C++
1 parent db157ed commit cdb5f29

File tree

4 files changed

+146
-14
lines changed

4 files changed

+146
-14
lines changed

paddle/framework/op_info.h

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,8 @@ class OpInfoMap {
8787
}
8888
}
8989

90-
template <typename Callback>
91-
void IterAllInfo(Callback callback) {
92-
for (auto& it : map_) {
93-
callback(it.first, it.second);
94-
}
90+
const std::unordered_map<std::string, const OpInfo>& map() const {
91+
return map_;
9592
}
9693

9794
private:

paddle/pybind/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,5 @@ if(WITH_PYTHON)
44
DEPS pybind python backward proto_desc tensor_array paddle_memory executor
55
${GLOB_OP_LIB})
66
endif(WITH_PYTHON)
7+
8+
cc_binary(print_operators_doc SRCS print_operators_doc.cc DEPS ${GLOB_OP_LIB} tensor_array)

paddle/pybind/print_operators_doc.cc

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
#include <iostream>
2+
#include <sstream> // std::stringstream
3+
#include <string>
4+
5+
#include "paddle/framework/op_info.h"
6+
#include "paddle/framework/op_registry.h"
7+
#include "paddle/pybind/pybind.h"
8+
9+
std::string Escape(const std::string& s) {
10+
std::string r;
11+
for (size_t i = 0; i < s.size(); i++) {
12+
switch (s[i]) {
13+
case '\"':
14+
r += "\\\"";
15+
break;
16+
case '\\':
17+
r += "\\\\";
18+
break;
19+
case '\n':
20+
r += "\\n";
21+
break;
22+
case '\t':
23+
r += "\\t";
24+
case '\r':
25+
break;
26+
default:
27+
r += s[i];
28+
break;
29+
}
30+
}
31+
return r;
32+
}
33+
34+
std::string AttrType(paddle::framework::AttrType at) {
35+
switch (at) {
36+
case paddle::framework::INT:
37+
return "int";
38+
case paddle::framework::FLOAT:
39+
return "float";
40+
case paddle::framework::STRING:
41+
return "string";
42+
case paddle::framework::BOOLEAN:
43+
return "bool";
44+
case paddle::framework::INTS:
45+
return "int array";
46+
case paddle::framework::FLOATS:
47+
return "float array";
48+
case paddle::framework::STRINGS:
49+
return "string array";
50+
case paddle::framework::BOOLEANS:
51+
return "bool array";
52+
case paddle::framework::BLOCK:
53+
return "block id";
54+
}
55+
return "UNKNOWN"; // not possible
56+
}
57+
58+
void PrintVar(const paddle::framework::OpProto::Var& v, std::stringstream& ss) {
59+
ss << " { "
60+
<< "\n"
61+
<< " \"name\" : \"" << Escape(v.name()) << "\",\n"
62+
<< " \"comment\" : \"" << Escape(v.comment()) << "\",\n"
63+
<< " \"duplicable\" : " << v.duplicable() << ",\n"
64+
<< " \"intermediate\" : " << v.intermediate() << "\n"
65+
<< " },";
66+
}
67+
68+
void PrintAttr(const paddle::framework::OpProto::Attr& a,
69+
std::stringstream& ss) {
70+
ss << " { "
71+
<< "\n"
72+
<< " \"name\" : \"" << Escape(a.name()) << "\",\n"
73+
<< " \"type\" : \"" << AttrType(a.type()) << "\",\n"
74+
<< " \"comment\" : \"" << Escape(a.comment()) << "\",\n"
75+
<< " \"generated\" : " << a.generated() << "\n"
76+
<< " },";
77+
}
78+
79+
void PrintOpProto(const std::string& type,
80+
const paddle::framework::OpInfo& opinfo,
81+
std::stringstream& ss) {
82+
std::cerr << "Processing " << type << "\n";
83+
84+
const paddle::framework::OpProto* p = opinfo.proto_;
85+
if (p == nullptr) {
86+
return; // It is possible that an operator doesn't have OpProto.
87+
}
88+
89+
ss << "{\n"
90+
<< " \"type\" : \"" << Escape(p->type()) << "\",\n"
91+
<< " \"comment\" : \"" << Escape(p->comment()) << "\",\n";
92+
93+
ss << " \"inputs\" : [ "
94+
<< "\n";
95+
for (int i = 0; i < p->inputs_size(); i++) {
96+
PrintVar(p->inputs(i), ss);
97+
}
98+
ss.seekp(-1, ss.cur); // remove the trailing comma
99+
ss << " ], "
100+
<< "\n";
101+
102+
ss << " \"outputs\" : [ "
103+
<< "\n";
104+
for (int i = 0; i < p->outputs_size(); i++) {
105+
PrintVar(p->outputs(i), ss);
106+
}
107+
ss.seekp(-1, ss.cur); // remove the trailing comma
108+
ss << " ], "
109+
<< "\n";
110+
111+
ss << " \"attrs\" : [ "
112+
<< "\n";
113+
for (int i = 0; i < p->attrs_size(); i++) {
114+
PrintAttr(p->attrs(i), ss);
115+
}
116+
ss.seekp(-1, ss.cur); // remove the trailing comma
117+
ss << " ] "
118+
<< "\n";
119+
120+
ss << "},";
121+
}
122+
123+
int main() {
124+
std::stringstream ss;
125+
ss << "[\n";
126+
for (auto& iter : paddle::framework::OpInfoMap::Instance().map()) {
127+
PrintOpProto(iter.first, iter.second, ss);
128+
}
129+
ss.seekp(-1, ss.cur); // remove the trailing comma
130+
ss << "]\n";
131+
std::cout << ss.str();
132+
}

paddle/pybind/pybind.cc

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -225,15 +225,16 @@ All parameter, weight, gradient are variables in Paddle.
225225
//! Python str. If you want a str object, you should cast them in Python.
226226
m.def("get_all_op_protos", []() -> std::vector<py::bytes> {
227227
std::vector<py::bytes> ret_values;
228-
229-
OpInfoMap::Instance().IterAllInfo([&ret_values](const std::string &type,
230-
const OpInfo &info) {
231-
if (!info.HasOpProtoAndChecker()) return;
232-
std::string str;
233-
PADDLE_ENFORCE(info.Proto().SerializeToString(&str),
234-
"Serialize OpProto Error. This could be a bug of Paddle.");
235-
ret_values.emplace_back(str);
236-
});
228+
for (auto &iter : OpInfoMap::Instance().map()) {
229+
auto &info = iter.second;
230+
if (info.HasOpProtoAndChecker()) {
231+
std::string str;
232+
PADDLE_ENFORCE(
233+
info.Proto().SerializeToString(&str),
234+
"Serialize OpProto Error. This could be a bug of Paddle.");
235+
ret_values.emplace_back(str);
236+
}
237+
}
237238
return ret_values;
238239
});
239240
m.def_submodule(

0 commit comments

Comments
 (0)