@@ -40,6 +40,25 @@ def read_tfjs_attr(attr, tf_dtypes=False):
40
40
return read_tfjs_attr_helper (k , attr [k ], tf_dtypes )
41
41
42
42
43
def fix_string_attr(tfjs_node):
    """
    Older tfjs models store strings as lists of ints (representing byte values). This function finds and replaces
    those strings, so protobuf can correctly decode the json.
    """
    if 'attr' not in tfjs_node:
        return

    def decode_legacy(value):
        # Legacy encoding: a list of byte values instead of a string.
        if isinstance(value, list):
            return base64.encodebytes(bytes(value)).decode()
        return value

    for attr_value in tfjs_node['attr'].values():
        if 's' in attr_value:
            attr_value['s'] = decode_legacy(attr_value['s'])
        if 'list' in attr_value and 's' in attr_value['list']:
            # Rewrite the existing list in place so any other references stay valid.
            strings = attr_value['list']['s']
            strings[:] = [decode_legacy(item) for item in strings]
60
+
61
+
43
62
def read_tfjs_attr_helper (k , v , tf_dtypes = False ):
44
63
"""
45
64
A tfjs attribute value is itself a dict with a single key specifying the type and a value with the actual data
@@ -49,12 +68,15 @@ def read_tfjs_attr_helper(k, v, tf_dtypes=False):
49
68
supported_types = ['func' , 'shape' , 'type' , 'list' , 's' , 'i' , 'f' , 'b' ]
50
69
utils .make_sure (k in supported_types , "Unrecognized tfjs attribute type %s" , k )
51
70
if k == 'list' :
52
- if len (v ) == 0 :
71
+ non_empty_keys = [k2 for k2 , v2 in v .items () if len (v2 ) > 0 ]
72
+ if len (non_empty_keys ) == 0 :
53
73
return []
54
- k2 = list ( v . keys ()) [0 ]
74
+ k2 = non_empty_keys [0 ]
55
75
return [read_tfjs_attr_helper (k2 , v2 , tf_dtypes ) for v2 in v [k2 ]]
56
76
if k == 'type' :
57
- dtype = getattr (types_pb2 , v )
77
+ dtype = v
78
+ if not isinstance (dtype , int ):
79
+ dtype = getattr (types_pb2 , dtype )
58
80
if not tf_dtypes :
59
81
dtype = tf_utils .map_tf_dtype (dtype )
60
82
return dtype
@@ -89,6 +111,7 @@ def resolve_output(output, op_info, func_name=None):
89
111
# If no port is specified, it is referring to port 0
90
112
if output in op_info :
91
113
return output + ':0'
114
+ # Output isn't from an op and may be an input (no port number)
92
115
return output
93
116
if cnt == 1 :
94
117
# Already in our standard format
@@ -146,9 +169,15 @@ def get_output_shapes(node_def, input_dtypes, input_shapes, inp_consts):
146
169
"""Returns a list of the output shapes of an op. input_dtypes should be tf dtypes."""
147
170
from tf2onnx .tf_loader import tf_session , tf_placeholder # pylint: disable=import-outside-toplevel
148
171
149
- if node_def .op == "Prelu" :
172
+ if node_def .op in [ "Prelu" , "Enter" ] :
150
173
return [input_shapes [0 ]]
151
174
175
+ if node_def .op == "Merge" :
176
+ # Find the first non-None shape (if it exists) and return it
177
+ non_none = ([t for t in input_shapes if t is not None ] + [None ])[0 ]
178
+ # The second output of merge is a scalar int indicating which input was selected
179
+ return [non_none , []]
180
+
152
181
del node_def .input [:]
153
182
node_def .name = "node"
154
183
@@ -355,7 +384,14 @@ def update_shapes(new_shapes):
355
384
placeholder_ops = ["Placeholder" , "PlaceholderWithDefault" , "PlaceholderV2" ]
356
385
graph_inputs = [n ['name' ] + ':0' for n in nodes if n ['op' ] in placeholder_ops ]
357
386
358
- unused_outputs = set ()
387
+ for node in nodes :
388
+ if node ['op' ] == "NextIteration" :
389
+ # NextIteration nodes can violate the topological sort with cyclic dependencies, so we do them first.
390
+ node_name = node ['name' ]
391
+ output_name = node_name + ':0'
392
+ output_shapes [output_name ] = None
393
+ tf_dtypes [output_name ] = read_tfjs_attr (node ['attr' ]['T' ], tf_dtypes = True )
394
+ op_info [node_name ] = (node ['op' ], {'dtype' : tf_dtypes [output_name ]}, [tf_dtypes [output_name ]])
359
395
360
396
for node in nodes :
361
397
op_type = node ['op' ]
@@ -376,6 +412,7 @@ def update_shapes(new_shapes):
376
412
continue
377
413
tf_attr = {}
378
414
onnx_attr = {}
415
+ fix_string_attr (node )
379
416
node_def = tfjs_node_to_tf_node_def (node )
380
417
for k , v in node .get ('attr' , {}).items ():
381
418
tf_attr [k ] = read_tfjs_attr (v , tf_dtypes = True )
@@ -396,7 +433,6 @@ def update_shapes(new_shapes):
396
433
397
434
input_names = [inp for inp in node .get ('input' , []) if not inp .startswith ('^' )]
398
435
input_names = [resolve_output (inp , op_info , func_name ) for inp in input_names ]
399
- unused_outputs .difference_update (input_names )
400
436
inp_dtypes = [tf_dtypes [inp ] for inp in input_names ]
401
437
inp_shapes = [output_shapes [inp ] for inp in input_names ]
402
438
inp_consts = [weights .get (inp .split (':' )[0 ]) for inp in input_names ]
@@ -407,7 +443,6 @@ def update_shapes(new_shapes):
407
443
output_names = [node_name + ":" + str (i ) for i in range (len (out_dtypes ))]
408
444
tf_dtypes .update (zip (output_names , out_dtypes ))
409
445
update_shapes (zip (output_names , out_shapes ))
410
- unused_outputs .update (output_names )
411
446
412
447
if op_type == "PlaceholderWithDefault" :
413
448
remove = False
@@ -429,7 +464,11 @@ def update_shapes(new_shapes):
429
464
430
465
dtypes = {k : tf_utils .map_tf_dtype (v ) for k , v in tf_dtypes .items ()}
431
466
if graph_outputs is None :
432
- graph_outputs = list (unused_outputs )
467
+ output_to_node = {out : node .name for node in onnx_nodes for out in node .output }
468
+ node_to_outputs = {node .name : list (node .output ) for node in onnx_nodes }
469
+ used_nodes = set (output_to_node [out ] for node in onnx_nodes for out in node .input )
470
+ unused_nodes = [node for node in onnx_nodes if node .name not in used_nodes ]
471
+ graph_outputs = [out for node in unused_nodes for out in node_to_outputs [node .name ]]
433
472
graph_outputs_mapped = [resolve_output (out , op_info , func_name ) for out in graph_outputs ]
434
473
435
474
g = Graph (onnx_nodes , output_shapes , dtypes , input_names = graph_inputs , output_names = graph_outputs_mapped ,
0 commit comments