@@ -53,8 +53,8 @@ def _optimize_graph(graph, aten):
 
 def get_node_id(node):
     import re
-    node_id = re.search(r"[\d]+", node.__str__())[0]
-    return node_id
+    node_id = re.search(r"[\d]+", node.__str__())
+    return node_id.group(0)
 
 
 def pytorch_to_keras(
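The change to get_node_id() avoids subscripting the re.Match object: match[0] only works on Python 3.6+, while match.group(0) works on every Python 3 release, which is presumably the motivation here. A minimal standalone sketch of the same idea, using a hypothetical FakeNode whose __str__() mimics a traced torch._C.Node string, and adding a guard for the no-match case that the function above does not have:

import re

class FakeNode:
    # Hypothetical stand-in for a traced torch._C.Node; only __str__() matters here.
    def __str__(self):
        return '%42 : Float(1, 64, 56, 56) = aten::relu(%41)'

def get_node_id(node):
    # .group(0) works on any Python 3; match[0] requires Python >= 3.6.
    match = re.search(r"[\d]+", node.__str__())
    return match.group(0) if match else None

print(get_node_id(FakeNode()))  # prints '42'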
@@ -103,11 +103,12 @@ def pytorch_to_keras(
 
     # Collect graph outputs
     graph_outputs = [n.uniqueName() for n in trace.graph().outputs()]
+    print('Graph outputs:', graph_outputs)
 
     # Collect model state dict
     state_dict = _unique_state_dict(model)
     if verbose:
-        print(list(state_dict))
+        print('State dict:', list(state_dict))
 
     import re
     import keras
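The two new print() calls expose the tracer-assigned output names and the state-dict keys that the converter matches against later. A minimal sketch of where trace and state_dict come from, assuming the old torch.jit tracing API (roughly PyTorch 0.4/1.0) that this converter targets; get_trace_graph and uniqueName were renamed or removed in later releases:

import torch
from torch.jit import _unique_state_dict

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())
dummy_input = torch.randn(1, 3, 32, 32)

# On the old API assumed here, get_trace_graph returns (trace, outputs).
trace, _ = torch.jit.get_trace_graph(model, (dummy_input,))
graph_outputs = [n.uniqueName() for n in trace.graph().outputs()]
print('Graph outputs:', graph_outputs)  # numeric value names assigned by the tracer

state_dict = _unique_state_dict(model)
print('State dict:', list(state_dict))  # keys such as '0.weight', '0.bias'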
@@ -173,21 +174,20 @@ def pytorch_to_keras(
         for layer in conf['layers']:
             if layer['config'] and 'batch_input_shape' in layer['config']:
                 layer['config']['batch_input_shape'] = \
-                    tuple(np.reshape(
+                    tuple(np.reshape(np.array(
                         [
-                            None,
-                            *layer['config']['batch_input_shape'][2:][:],
-                            layer['config']['batch_input_shape'][1]
-                        ], -1
+                            [None] +
+                            list(layer['config']['batch_input_shape'][2:][:]) +
+                            [layer['config']['batch_input_shape'][1]]
+                        ]), -1
                     ))
-
             if layer['config'] and 'target_shape' in layer['config']:
                 layer['config']['target_shape'] = \
-                    tuple(np.reshape(
+                    tuple(np.reshape(np.array(
                         [
-                            *layer['config']['target_shape'][1:][:],
+                            list(layer['config']['target_shape'][1:][:]),
                             layer['config']['target_shape'][0]
-                        ], -1
+                        ]), -1
                     ))
             if layer['config'] and 'data_format' in layer['config']:
                 layer['config']['data_format'] = 'channels_last'
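These hunks replace the starred-unpacking list literals (Python 3.5+ syntax) with plain concatenation wrapped in np.array(), keeping the channels-first to channels-last shape rewrite. A minimal sketch of that permutation on hypothetical shape values rather than a real Keras config; note that the target_shape branch above appears to pass [list, scalar] to np.array(), which NumPy treats as a ragged object array, so the flat concatenation used below may be closer to what is intended:

import numpy as np

# Hypothetical channels-first shapes as stored in a converted Keras config:
# (batch, channels, height, width) and (channels, height, width).
batch_input_shape = (None, 3, 224, 224)
target_shape = (3, 224, 224)

# Move channels to the last axis; the extra outer brackets keep np.array()
# from seeing a ragged sequence.
nhwc_input = tuple(np.reshape(np.array(
    [[None] + list(batch_input_shape[2:]) + [batch_input_shape[1]]]
), -1))
print(nhwc_input)   # (None, 224, 224, 3)

# Same idea without the batch axis, using plain list concatenation.
nhwc_target = tuple(list(target_shape[1:]) + [target_shape[0]])
print(nhwc_target)  # (224, 224, 3)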