|
18 | 18 | from ....framework import Block
|
19 | 19 | from .... import core
|
20 | 20 |
|
21 |
| -__all__ = ['Graph', 'ImitationGraph', 'IRGraph', 'PyGraph'] |
22 |
| - |
23 |
| - |
24 |
| -class PyGraph(object): |
25 |
| - """ |
26 |
| - PyGraph uses core.Graph as the delegation to accomplish the manipulation. |
27 |
| - """ |
28 |
| - |
29 |
| - def __init__(self, graph, for_test=False): |
30 |
| - """ |
31 |
| - Construct the PyGraph using core.Graph. |
32 |
| - Args: |
33 |
| - graph(core.Graph): C++ Graph. |
34 |
| - for_test(bool): True for the test graph and false for the train graph. |
35 |
| - """ |
36 |
| - assert isinstance( |
37 |
| - graph, core.Graph), 'graph must be the instance of core.Graph.' |
38 |
| - self.graph = graph |
39 |
| - self.for_test = for_test |
40 |
| - |
41 |
| - def is_test(self): |
42 |
| - return self.for_test |
43 |
| - |
44 |
| - def all_parameters(self): |
45 |
| - param_nodes = set() |
46 |
| - for node in self.graph.nodes(): |
47 |
| - if node.is_var() and node.var() is not None and node.var( |
48 |
| - ).persistable(): |
49 |
| - param_nodes.add(node) |
50 |
| - return param_nodes |
51 |
| - |
52 |
| - def all_vars(self): |
53 |
| - return {node for node in self.graph.nodes() if node.is_var()} |
54 |
| - |
55 |
| - def all_ops(self): |
56 |
| - return {node for node in self.graph.nodes() if node.is_op()} |
57 |
| - |
58 |
| - def create_param_node(self, name, var_type, shape, var_dtype): |
59 |
| - var_desc = core.VarDesc(name) |
60 |
| - var_desc.set_type(var_type) |
61 |
| - var_desc.set_shape(shape) |
62 |
| - var_desc.set_dtype(var_dtype) |
63 |
| - var_desc.set_persistable(True) |
64 |
| - return self.graph.create_var_node(var_desc) |
65 |
| - |
66 |
| - def create_var_node(self, name, var_type, shape, var_dtype): |
67 |
| - var_desc = core.VarDesc(name) |
68 |
| - var_desc.set_type(var_type) |
69 |
| - var_desc.set_shape(shape) |
70 |
| - var_desc.set_dtype(var_dtype) |
71 |
| - return self.graph.create_var_node(var_desc) |
72 |
| - |
73 |
| - def create_var_node_from_desc(self, var_desc): |
74 |
| - return self.graph.create_var_node(var_desc) |
75 |
| - |
76 |
| - def create_op_node(self, op_type, attrs, inputs, outputs): |
77 |
| - op_desc = core.OpDesc() |
78 |
| - op_desc.set_type(op_type) |
79 |
| - for attr, value in attrs.iteritems(): |
80 |
| - self._update_desc_attr(op_desc, attr, value) |
81 |
| - for input_name, var_nodes in inputs.iteritems(): |
82 |
| - if not isinstance(var_nodes, list): |
83 |
| - var_nodes = [var_nodes] |
84 |
| - op_desc.set_input(input_name, |
85 |
| - [var_node.name() for var_node in var_nodes]) |
86 |
| - for output_name, var_nodes in outputs.iteritems(): |
87 |
| - if not isinstance(var_nodes, list): |
88 |
| - var_nodes = [var_nodes] |
89 |
| - op_desc.set_output(output_name, |
90 |
| - [var_node.name() for var_node in var_nodes]) |
91 |
| - return self.graph.create_op_node(op_desc) |
92 |
| - |
93 |
| - def create_op_node_from_desc(self, op_desc): |
94 |
| - return self.graph.create_op_node(op_desc) |
95 |
| - |
96 |
| - def _update_desc_attr(self, desc, name, val): |
97 |
| - """ |
98 |
| - Update the value of desc's attribute by attribute's name. |
99 |
| - """ |
100 |
| - if isinstance(val, Block): |
101 |
| - desc.set_block_attr(name, val.desc) |
102 |
| - elif isinstance(val, list) and val and all( |
103 |
| - isinstance(v, Block) for v in val): |
104 |
| - desc.set_blocks_attr(name, [v.desc for v in val]) |
105 |
| - elif isinstance(val, core.BlockDesc) or \ |
106 |
| - isinstance(val, core.ProgramDesc): |
107 |
| - desc.set_serialized_attr(name, val.serialize_to_string()) |
108 |
| - else: |
109 |
| - desc._set_attr(name, val) |
110 |
| - |
111 |
| - def safe_remove_nodes(self, remove_nodes): |
112 |
| - if not isinstance(remove_nodes, set): |
113 |
| - remove_nodes = set(remove_nodes) |
114 |
| - core.graph_safe_remove_nodes(self.graph, remove_nodes) |
115 |
| - |
116 |
| - def draw(self, save_path, name, marked_nodes=None): |
117 |
| - def _convert_to_pdf(dot_file_path): |
118 |
| - pdf_save_path = os.path.splitext(dot_file_path)[0] + '.pdf' |
119 |
| - exited_code = subprocess.call('dot -Tpdf ' + dot_file_path \ |
120 |
| - + ' -o ' + pdf_save_path, shell=True) |
121 |
| - if exited_code != 0: |
122 |
| - print('The dot command is needed for creating pdf files.') |
123 |
| - print('The {} is saved as the dot filetype.'.format( |
124 |
| - dot_file_path)) |
125 |
| - |
126 |
| - remove_ctr_vars = set() |
127 |
| - ops_num = 0 |
128 |
| - for node in self.graph.nodes(): |
129 |
| - if node.is_ctrl_var(): |
130 |
| - remove_ctr_vars.add(node) |
131 |
| - elif node.is_op(): |
132 |
| - ops_num += 1 |
133 |
| - print('Total ops num = {}.'.format(ops_num)) |
134 |
| - self.safe_remove_nodes(remove_ctr_vars) |
135 |
| - if marked_nodes is not None: |
136 |
| - if not isinstance(marked_nodes, set): |
137 |
| - marked_nodes = set(marked_nodes) |
138 |
| - marked_nodes = marked_nodes - remove_ctr_vars |
139 |
| - if self.graph.has('__graphviz__marked_node__'): |
140 |
| - self.graph.erase('__graphviz__marked_node__') |
141 |
| - self.graph.set('__graphviz__marked_node__', marked_nodes) |
142 |
| - viz_dot_path = os.path.join(save_path, name) + '.dot' |
143 |
| - viz_pass = core.get_pass('graph_viz_pass') |
144 |
| - viz_pass.set_str('graph_viz_path', viz_dot_path) |
145 |
| - viz_pass.apply(self.graph) |
146 |
| - _convert_to_pdf(viz_dot_path) |
147 |
| - |
148 |
| - def to_program(self): |
149 |
| - convert_pass = core.get_pass('graph_to_program_pass') |
150 |
| - convert_pass.set_program('program', Program().desc) |
151 |
| - convert_pass.apply(self.graph) |
152 |
| - desc = convert_pass.get_program('program') |
153 |
| - program = Program.construct_from_desc(desc) |
154 |
| - return program |
| 21 | +__all__ = ['Graph', 'ImitationGraph', 'IRGraph'] |
155 | 22 |
|
156 | 23 |
|
157 | 24 | class Graph(object):
|
|
0 commit comments