Skip to content

Commit 760a555

Browse files
Fix TFJS conversion for old TFJS models (#1629)
* Add conversion of _FusedDepthwiseConv2dNative Signed-off-by: Tom Wildenhain <[email protected]> * Implement PRelu for tfjs Signed-off-by: Tom Wildenhain <[email protected]> * Fix tfjs for old models Signed-off-by: Tom Wildenhain <[email protected]> * Remove unused variable Signed-off-by: Tom Wildenhain <[email protected]> * Pylint Signed-off-by: Tom Wildenhain <[email protected]>
1 parent b15ae91 commit 760a555

File tree

4 files changed

+72
-10
lines changed

4 files changed

+72
-10
lines changed

tests/run_pretrained_models.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,11 @@ def get_beach(shape):
7171
return get_img(shape, "beach.jpg", np.float32, should_scale=True)
7272

7373

74+
def get_beach_uint8(shape):
75+
"""Get beach image as uint8."""
76+
return get_img(shape, "ade20k.jpg", np.uint8, should_scale=False)
77+
78+
7479
def get_car(shape):
7580
"""Get car image as input."""
7681
return get_img(shape, "car.JPEG", np.float32, should_scale=True)
@@ -152,6 +157,7 @@ def get_sentence():
152157

153158
_INPUT_FUNC_MAPPING = {
154159
"get_beach": get_beach,
160+
"get_beach_uint8": get_beach_uint8,
155161
"get_car": get_car,
156162
"get_ade20k": get_ade20k,
157163
"get_ade20k_uint8": get_ade20k_uint8,

tests/run_pretrained_models.yaml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -670,3 +670,21 @@ facemesh_tfjs:
670670
- Identity_2:0
671671
rtol: 0.05
672672
atol: 0.0005
673+
674+
ssd_mobilenet_v1_tfjs:
675+
tf_min_version: 2.1
676+
disabled: false
677+
url: https://tfhub.dev/tensorflow/tfjs-model/ssd_mobilenet_v1/1/default/1?tfjs-format=compressed
678+
model: "model.json"
679+
opset_constraints:
680+
"onnx":
681+
"min": 9
682+
model_type: tfjs
683+
input_get: get_beach_uint8
684+
inputs:
685+
"image_tensor:0": [1, 200, 200, 3]
686+
outputs:
687+
- Postprocessor/Slice:0
688+
- Postprocessor/ExpandDims_1:0
689+
rtol: 0.05
690+
atol: 0.0005

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -777,8 +777,7 @@ def _calculate_new_attr(ori_perm, ori_squeeze_axes):
777777
if input_shape is not None:
778778
new_squeeze_output_shape = [input_shape[i] for i in range(trans_rank) if i not in new_squeeze_axes]
779779
else:
780-
new_squeeze_output_shape = [-1] * trans_rank
781-
self.logger.warning("%s's shape is unknown, which may interfere further optimization", node.input[0])
780+
new_squeeze_output_shape = [-1] * (trans_rank - len(new_squeeze_axes))
782781
self._g.set_shape(node.output[0], new_squeeze_output_shape)
783782
return True
784783
return False

tf2onnx/tfjs_utils.py

Lines changed: 47 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,25 @@ def read_tfjs_attr(attr, tf_dtypes=False):
4040
return read_tfjs_attr_helper(k, attr[k], tf_dtypes)
4141

4242

43+
def fix_string_attr(tfjs_node):
44+
"""
45+
Older tfjs models store strings as lists of ints (representing byte values). This function finds and replaces
46+
those strings, so protobuf can correctly decode the json.
47+
"""
48+
def fix(v):
49+
if isinstance(v, list):
50+
return base64.encodebytes(bytes(v)).decode()
51+
return v
52+
if 'attr' not in tfjs_node:
53+
return
54+
for v in tfjs_node['attr'].values():
55+
if 's' in v:
56+
v['s'] = fix(v['s'])
57+
if 'list' in v and 's' in v['list']:
58+
for i, x in enumerate(v['list']['s']):
59+
v['list']['s'][i] = fix(x)
60+
61+
4362
def read_tfjs_attr_helper(k, v, tf_dtypes=False):
4463
"""
4564
A tfjs attribute value is itself a dict with a single key specifying the type and a value with the actual data
@@ -49,12 +68,15 @@ def read_tfjs_attr_helper(k, v, tf_dtypes=False):
4968
supported_types = ['func', 'shape', 'type', 'list', 's', 'i', 'f', 'b']
5069
utils.make_sure(k in supported_types, "Unrecognized tfjs attribute type %s", k)
5170
if k == 'list':
52-
if len(v) == 0:
71+
non_empty_keys = [k2 for k2, v2 in v.items() if len(v2) > 0]
72+
if len(non_empty_keys) == 0:
5373
return []
54-
k2 = list(v.keys())[0]
74+
k2 = non_empty_keys[0]
5575
return [read_tfjs_attr_helper(k2, v2, tf_dtypes) for v2 in v[k2]]
5676
if k == 'type':
57-
dtype = getattr(types_pb2, v)
77+
dtype = v
78+
if not isinstance(dtype, int):
79+
dtype = getattr(types_pb2, dtype)
5880
if not tf_dtypes:
5981
dtype = tf_utils.map_tf_dtype(dtype)
6082
return dtype
@@ -89,6 +111,7 @@ def resolve_output(output, op_info, func_name=None):
89111
# If no port is specified, it is referring to port 0
90112
if output in op_info:
91113
return output + ':0'
114+
# Output isn't from an op and may be an input (no port number)
92115
return output
93116
if cnt == 1:
94117
# Already in our standard format
@@ -146,9 +169,15 @@ def get_output_shapes(node_def, input_dtypes, input_shapes, inp_consts):
146169
"""Returns a list of the output shapes of an op. input_dtypes should be tf dtypes."""
147170
from tf2onnx.tf_loader import tf_session, tf_placeholder # pylint: disable=import-outside-toplevel
148171

149-
if node_def.op == "Prelu":
172+
if node_def.op in ["Prelu", "Enter"]:
150173
return [input_shapes[0]]
151174

175+
if node_def.op == "Merge":
176+
# Find the first non-None shape (if it exists) and return it
177+
non_none = ([t for t in input_shapes if t is not None] + [None])[0]
178+
# The second output of merge is a scalar int indicating which input was selected
179+
return [non_none, []]
180+
152181
del node_def.input[:]
153182
node_def.name = "node"
154183

@@ -355,7 +384,14 @@ def update_shapes(new_shapes):
355384
placeholder_ops = ["Placeholder", "PlaceholderWithDefault", "PlaceholderV2"]
356385
graph_inputs = [n['name'] + ':0' for n in nodes if n['op'] in placeholder_ops]
357386

358-
unused_outputs = set()
387+
for node in nodes:
388+
if node['op'] == "NextIteration":
389+
# NextIteration nodes can violate the topological sort with cyclic dependencies, so we do them first.
390+
node_name = node['name']
391+
output_name = node_name + ':0'
392+
output_shapes[output_name] = None
393+
tf_dtypes[output_name] = read_tfjs_attr(node['attr']['T'], tf_dtypes=True)
394+
op_info[node_name] = (node['op'], {'dtype': tf_dtypes[output_name]}, [tf_dtypes[output_name]])
359395

360396
for node in nodes:
361397
op_type = node['op']
@@ -376,6 +412,7 @@ def update_shapes(new_shapes):
376412
continue
377413
tf_attr = {}
378414
onnx_attr = {}
415+
fix_string_attr(node)
379416
node_def = tfjs_node_to_tf_node_def(node)
380417
for k, v in node.get('attr', {}).items():
381418
tf_attr[k] = read_tfjs_attr(v, tf_dtypes=True)
@@ -396,7 +433,6 @@ def update_shapes(new_shapes):
396433

397434
input_names = [inp for inp in node.get('input', []) if not inp.startswith('^')]
398435
input_names = [resolve_output(inp, op_info, func_name) for inp in input_names]
399-
unused_outputs.difference_update(input_names)
400436
inp_dtypes = [tf_dtypes[inp] for inp in input_names]
401437
inp_shapes = [output_shapes[inp] for inp in input_names]
402438
inp_consts = [weights.get(inp.split(':')[0]) for inp in input_names]
@@ -407,7 +443,6 @@ def update_shapes(new_shapes):
407443
output_names = [node_name + ":" + str(i) for i in range(len(out_dtypes))]
408444
tf_dtypes.update(zip(output_names, out_dtypes))
409445
update_shapes(zip(output_names, out_shapes))
410-
unused_outputs.update(output_names)
411446

412447
if op_type == "PlaceholderWithDefault":
413448
remove = False
@@ -429,7 +464,11 @@ def update_shapes(new_shapes):
429464

430465
dtypes = {k: tf_utils.map_tf_dtype(v) for k, v in tf_dtypes.items()}
431466
if graph_outputs is None:
432-
graph_outputs = list(unused_outputs)
467+
output_to_node = {out: node.name for node in onnx_nodes for out in node.output}
468+
node_to_outputs = {node.name: list(node.output) for node in onnx_nodes}
469+
used_nodes = set(output_to_node[out] for node in onnx_nodes for out in node.input)
470+
unused_nodes = [node for node in onnx_nodes if node.name not in used_nodes]
471+
graph_outputs = [out for node in unused_nodes for out in node_to_outputs[node.name]]
433472
graph_outputs_mapped = [resolve_output(out, op_info, func_name) for out in graph_outputs]
434473

435474
g = Graph(onnx_nodes, output_shapes, dtypes, input_names=graph_inputs, output_names=graph_outputs_mapped,

0 commit comments

Comments
 (0)