Skip to content

Commit b913463

Browse files
wzzjuroot
authored andcommitted
Update according to the reviewers' suggestion. test=develop
1 parent 3ce6172 commit b913463

File tree

6 files changed

+228
-271
lines changed

6 files changed

+228
-271
lines changed

paddle/fluid/pybind/ir.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,8 @@ void BindNode(py::module *m) {
148148
})
149149
.def("outputs_append",
150150
[](Node &self, Node &node) { self.outputs.push_back(&node); })
151-
.def_readonly("inputs", &Node::inputs)
152-
.def_readonly("outputs", &Node::outputs);
151+
.def_readwrite("inputs", &Node::inputs)
152+
.def_readwrite("outputs", &Node::outputs);
153153

154154
py::enum_<Node::Type>(node, "Type")
155155
.value("Operation", Node::Type::kOperation)

paddle/fluid/pybind/pybind.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -797,18 +797,18 @@ All parameter, weight, gradient are variables in Paddle.
797797
py::class_<ir::Pass, std::shared_ptr<ir::Pass>> pass(m, "Pass");
798798
pass.def(py::init())
799799
.def("has", &ir::Pass::Has)
800-
.def("set_program",
800+
.def("set",
801801
[](ir::Pass &self, const std::string &attr_name,
802802
const ProgramDesc &attr) {
803803
return self.Set(attr_name, new ProgramDesc(attr));
804804
})
805805
.def(
806-
"set_str",
806+
"set",
807807
[](ir::Pass &self, const std::string &name, const std::string &attr) {
808808
self.Set<std::string>(name, new std::string(attr));
809809
})
810-
.def("set_int", [](ir::Pass &self, const std::string &name,
811-
int val) { self.Set<const int>(name, new int(val)); })
810+
.def("set", [](ir::Pass &self, const std::string &name,
811+
int val) { self.Set<const int>(name, new int(val)); })
812812
.def("get_program", &ir::Pass::Get<ProgramDesc>)
813813
.def("type", &ir::Pass::Type)
814814
.def("apply", [](ir::Pass &self, std::shared_ptr<ir::Graph> graph) {

python/paddle/fluid/contrib/slim/graph/graph.py

Lines changed: 1 addition & 134 deletions
Original file line numberDiff line numberDiff line change
@@ -18,140 +18,7 @@
1818
from ....framework import Block
1919
from .... import core
2020

21-
__all__ = ['Graph', 'ImitationGraph', 'IRGraph', 'PyGraph']
22-
23-
24-
class PyGraph(object):
25-
"""
26-
PyGraph uses core.Graph as the delegation to accomplish the manipulation.
27-
"""
28-
29-
def __init__(self, graph, for_test=False):
30-
"""
31-
Construct the PyGraph using core.Graph.
32-
Args:
33-
graph(core.Graph): C++ Graph.
34-
for_test(bool): True for the test graph and false for the train graph.
35-
"""
36-
assert isinstance(
37-
graph, core.Graph), 'graph must be the instance of core.Graph.'
38-
self.graph = graph
39-
self.for_test = for_test
40-
41-
def is_test(self):
42-
return self.for_test
43-
44-
def all_parameters(self):
45-
param_nodes = set()
46-
for node in self.graph.nodes():
47-
if node.is_var() and node.var() is not None and node.var(
48-
).persistable():
49-
param_nodes.add(node)
50-
return param_nodes
51-
52-
def all_vars(self):
53-
return {node for node in self.graph.nodes() if node.is_var()}
54-
55-
def all_ops(self):
56-
return {node for node in self.graph.nodes() if node.is_op()}
57-
58-
def create_param_node(self, name, var_type, shape, var_dtype):
59-
var_desc = core.VarDesc(name)
60-
var_desc.set_type(var_type)
61-
var_desc.set_shape(shape)
62-
var_desc.set_dtype(var_dtype)
63-
var_desc.set_persistable(True)
64-
return self.graph.create_var_node(var_desc)
65-
66-
def create_var_node(self, name, var_type, shape, var_dtype):
67-
var_desc = core.VarDesc(name)
68-
var_desc.set_type(var_type)
69-
var_desc.set_shape(shape)
70-
var_desc.set_dtype(var_dtype)
71-
return self.graph.create_var_node(var_desc)
72-
73-
def create_var_node_from_desc(self, var_desc):
74-
return self.graph.create_var_node(var_desc)
75-
76-
def create_op_node(self, op_type, attrs, inputs, outputs):
77-
op_desc = core.OpDesc()
78-
op_desc.set_type(op_type)
79-
for attr, value in attrs.iteritems():
80-
self._update_desc_attr(op_desc, attr, value)
81-
for input_name, var_nodes in inputs.iteritems():
82-
if not isinstance(var_nodes, list):
83-
var_nodes = [var_nodes]
84-
op_desc.set_input(input_name,
85-
[var_node.name() for var_node in var_nodes])
86-
for output_name, var_nodes in outputs.iteritems():
87-
if not isinstance(var_nodes, list):
88-
var_nodes = [var_nodes]
89-
op_desc.set_output(output_name,
90-
[var_node.name() for var_node in var_nodes])
91-
return self.graph.create_op_node(op_desc)
92-
93-
def create_op_node_from_desc(self, op_desc):
94-
return self.graph.create_op_node(op_desc)
95-
96-
def _update_desc_attr(self, desc, name, val):
97-
"""
98-
Update the value of desc's attribute by attribute's name.
99-
"""
100-
if isinstance(val, Block):
101-
desc.set_block_attr(name, val.desc)
102-
elif isinstance(val, list) and val and all(
103-
isinstance(v, Block) for v in val):
104-
desc.set_blocks_attr(name, [v.desc for v in val])
105-
elif isinstance(val, core.BlockDesc) or \
106-
isinstance(val, core.ProgramDesc):
107-
desc.set_serialized_attr(name, val.serialize_to_string())
108-
else:
109-
desc._set_attr(name, val)
110-
111-
def safe_remove_nodes(self, remove_nodes):
112-
if not isinstance(remove_nodes, set):
113-
remove_nodes = set(remove_nodes)
114-
core.graph_safe_remove_nodes(self.graph, remove_nodes)
115-
116-
def draw(self, save_path, name, marked_nodes=None):
117-
def _convert_to_pdf(dot_file_path):
118-
pdf_save_path = os.path.splitext(dot_file_path)[0] + '.pdf'
119-
exited_code = subprocess.call('dot -Tpdf ' + dot_file_path \
120-
+ ' -o ' + pdf_save_path, shell=True)
121-
if exited_code != 0:
122-
print('The dot command is needed for creating pdf files.')
123-
print('The {} is saved as the dot filetype.'.format(
124-
dot_file_path))
125-
126-
remove_ctr_vars = set()
127-
ops_num = 0
128-
for node in self.graph.nodes():
129-
if node.is_ctrl_var():
130-
remove_ctr_vars.add(node)
131-
elif node.is_op():
132-
ops_num += 1
133-
print('Total ops num = {}.'.format(ops_num))
134-
self.safe_remove_nodes(remove_ctr_vars)
135-
if marked_nodes is not None:
136-
if not isinstance(marked_nodes, set):
137-
marked_nodes = set(marked_nodes)
138-
marked_nodes = marked_nodes - remove_ctr_vars
139-
if self.graph.has('__graphviz__marked_node__'):
140-
self.graph.erase('__graphviz__marked_node__')
141-
self.graph.set('__graphviz__marked_node__', marked_nodes)
142-
viz_dot_path = os.path.join(save_path, name) + '.dot'
143-
viz_pass = core.get_pass('graph_viz_pass')
144-
viz_pass.set_str('graph_viz_path', viz_dot_path)
145-
viz_pass.apply(self.graph)
146-
_convert_to_pdf(viz_dot_path)
147-
148-
def to_program(self):
149-
convert_pass = core.get_pass('graph_to_program_pass')
150-
convert_pass.set_program('program', Program().desc)
151-
convert_pass.apply(self.graph)
152-
desc = convert_pass.get_program('program')
153-
program = Program.construct_from_desc(desc)
154-
return program
21+
__all__ = ['Graph', 'ImitationGraph', 'IRGraph']
15522

15623

15724
class Graph(object):

0 commit comments

Comments
 (0)