@@ -50,6 +50,12 @@ def _optimize_trace(trace, aten):
50
50
torch ._C ._jit_pass_lint (trace )
51
51
52
52
53
+ def get_node_id (node ):
54
+ import re
55
+ node_id = re .search (r"[\d]+" , node .__str__ ())[0 ]
56
+ return node_id
57
+
58
+
53
59
def pytorch_to_keras (
54
60
model , args , input_shape ,
55
61
change_ordering = False , training = False , verbose = False
@@ -84,6 +90,9 @@ def pytorch_to_keras(
84
90
85
91
_optimize_trace (trace , False )
86
92
93
+ if verbose :
94
+ print (trace .graph ())
95
+
87
96
if verbose :
88
97
print (list (trace .graph ().outputs ()))
89
98
@@ -115,14 +124,16 @@ def pytorch_to_keras(
115
124
node_input_names = []
116
125
for node_input in node_inputs :
117
126
if node_input .node ().scopeName ():
118
- node_input_names .append (node_input .node (). scopeName ( ))
127
+ node_input_names .append (get_node_id ( node_input .node ()))
119
128
120
129
if len (node_input_names ) == 0 :
121
130
node_input_names .append ('input' )
122
131
123
132
node_type = node .kind ()
133
+ # print(dir(node))
134
+
124
135
node_scope_name = node .scopeName ()
125
- node_id = re . search ( r"[\d]+" , node . __str__ ())[ 0 ]
136
+ node_id = get_node_id ( node )
126
137
node_weights_name = '.' .join (
127
138
re .findall (r'\[([\w\d.]+)\]' , node_scope_name )
128
139
)
@@ -145,12 +156,12 @@ def pytorch_to_keras(
145
156
print ('is_terminal:' , node_id in graph_outputs )
146
157
AVAILABLE_CONVERTERS [node_type ](
147
158
node_attrs ,
148
- node_weights_name , node_scope_name ,
159
+ node_weights_name , node_id ,
149
160
node_input_names ,
150
161
layers , state_dict
151
162
)
152
163
if node_id in graph_outputs :
153
- outputs .append (layers [node_scope_name ])
164
+ outputs .append (layers [node_id ])
154
165
155
166
model = keras .models .Model (inputs = layers ['input' ], outputs = outputs )
156
167
0 commit comments