Commit b055709

make mobilenet work for pytorch
1 parent 48ca297

4 files changed: +11 −74 lines changed

tests/test_graph.py

Lines changed: 2 additions & 4 deletions
@@ -231,10 +231,8 @@ def test_relu6(self):
         _ = tf.identity(x_, name="output")
         g = process_tf_graph(sess.graph)
         self.assertEqual(
-            'digraph { Relu6__3 [op_type=Const] Relu6__2 [op_type=Const] '
-            'input1 [op_type=Placeholder shape="[2, 3]"] Relu6 [op_type=Max] '
-            'Relu6__4 [op_type=Min] output [op_type=Identity] input1:0 -> Relu6 '
-            'Relu6__2 -> Relu6 Relu6:0 -> Relu6__4 Relu6__3 -> Relu6__4 Relu6__4:0 -> output }',
+            'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Relu6 [op_type=Relu] Relu6__2 [op_type=Clip] '
+            'output [op_type=Identity] input1:0 -> Relu6 Relu6:0 -> Relu6__2 Relu6__2:0 -> output }',
             onnx_to_graphviz(g))

     def test_conv2d(self):

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 1 addition & 0 deletions
@@ -187,6 +187,7 @@ def _initialize_handlers(self):
             "Pad": self._pad_handler,
             "ReduceMean": self._reducemean_handler,
             "Relu": self._simple_through_handler,
+            "Clip": self._simple_through_handler,
             "Slice": self._slice_handler,
             "Split": self._split_handler,
             "Tanh": self._simple_through_handler,

tf2onnx/tfonnx.py

Lines changed: 5 additions & 70 deletions
@@ -634,75 +634,11 @@ def pool_op(ctx, node, name, args):

 def relu6_op(ctx, node, name, args):
     # relu6 = min(max(features, 0), 6)
-    # since onnx does not have relu6, compose it with multiple ops.
-    old_output = node.output[0]
-    dtype = ctx.get_dtype(node.input[0])
-    dtype = utils.ONNX_TO_NUMPY_DTYPE[dtype] if dtype else np.float32
-    shape = ctx.get_shape(node.input[0])
-    nodes = []
-    if -1 in shape:
-        # if the shape has unknown dims we need to do something like this for opset < 8 (=no broadcast for min/max):
-        # tz = sub(features, features)
-        # t6 = add(6, tz)
-        # relu6 = min(max(features, t0), t6)
-        input_node = node.inputs[0]
-        node.type = "Max"
-
-        # const tensor 6
-        six_name = utils.make_name(node.name)
-        nodes.append(ctx.make_const(six_name, np.array([6.], dtype=dtype)))
-
-        # get a tensor of input shape with zeros
-        sub_node = ctx.make_node("Sub", [node.input[0], node.input[0]], op_name_scope=input_node.name)
-        node.input.append(sub_node.output[0])
-
-        # get a tensor of input shape with 6
-        add_node = ctx.make_node("Add", [six_name, sub_node.output[0]], op_name_scope=input_node.name)
-
-        min_name = utils.make_name(node.name)
-        min_node = ctx.insert_new_node_on_output("Min", node.output[0], name=min_name)
-        min_node.input.append(add_node.output[0])
-        ctx.copy_shape(old_output, min_node.output[0])
-        nodes.extend([sub_node, add_node, node, min_node])
-        return nodes
-
-    # if there is no unknown dim in shape we can use constants
-    node.type = "Max"
-    zero_name = utils.make_name(node.name)
-    nodes.append(ctx.make_const(zero_name, np.zeros(shape, dtype=dtype)))
-    six_name = utils.make_name(node.name)
-    six = np.zeros(shape, dtype=dtype)
-    six.fill(6)
-    nodes.append(ctx.make_const(six_name, six))
-    node.input.append(zero_name)
-    min_name = utils.make_name(node.name)
-    min_node = ctx.insert_new_node_on_output("Min", node.output[0], name=min_name)
-    min_node.input.append(six_name)
-    ctx.copy_shape(old_output, min_node.output[0])
-    nodes.extend([node, min_node])
-    return nodes
-
-
-def relu6_op8(ctx, node, name, args):
-    # relu6 = min(max(features, 0), 6) for opset >= 8
-    # since onnx does not have relu6, compose it with multiple ops.
-    old_output = node.output[0]
-    dtype = ctx.get_dtype(node.input[0])
-    dtype = utils.ONNX_TO_NUMPY_DTYPE[dtype] if dtype else np.float32
-    node.type = "Max"
-    nodes = []
-    # const tensor 6
-    six_name = utils.make_name(node.name)
-    nodes.append(ctx.make_const(six_name, np.array([6], dtype=dtype)))
-    zero_name = utils.make_name(node.name)
-    nodes.append(ctx.make_const(zero_name, np.array([0], dtype=dtype)))
-    node.input.append(zero_name)
-    min_name = utils.make_name(node.name)
-    min_node = ctx.insert_new_node_on_output("Min", node.output[0], name=min_name)
-    min_node.input.append(six_name)
-    ctx.copy_shape(old_output, min_node.output[0])
-    nodes.extend([node, min_node])
-    return nodes
+    node.type = "Relu"
+    clip_name = utils.make_name(node.name)
+    clip_node = ctx.insert_new_node_on_output("Clip", node.output[0], name=clip_name, min=0.0, max=6.0)
+    ctx.copy_shape(node.output[0], clip_node.output[0])
+    return [node, clip_node]


 def squareddifference_op(ctx, node, name, args):
@@ -1901,7 +1837,6 @@ def where_op(ctx, node, name, args):
 }

 _OPSET_8 = {
-    "Relu6": (relu6_op8, []),  # make use of min/max broadcast
     "ReverseSequence": (reverse_op8, []),  # make use of scan
     "Select": (select_op8, []),
 }
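
With this change the converter no longer composes Max/Min with constant tensors: since ONNX has no Relu6 op, the TF node is retyped to Relu and a Clip with min=0, max=6 is inserted on its output, which equals TensorFlow's min(max(x, 0), 6). A small numpy sketch checking the equivalence (illustrative, not part of the commit):

    import numpy as np

    def relu6_reference(x):
        # TensorFlow's definition: min(max(x, 0), 6)
        return np.minimum(np.maximum(x, 0.0), 6.0)

    def relu_then_clip(x):
        # What the converter now emits: Relu followed by Clip(min=0, max=6)
        return np.clip(np.maximum(x, 0.0), 0.0, 6.0)

    x = np.linspace(-3.0, 9.0, 13, dtype=np.float32)
    assert np.array_equal(relu6_reference(x), relu_then_clip(x))

Strictly, Clip(min=0, max=6) alone already implements relu6; keeping the retyped Relu lets the handler rewrite the existing node in place and insert only one new node on its output.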

tools/onnx-experiments.py

Lines changed: 3 additions & 0 deletions
@@ -38,6 +38,9 @@ def get_args():

 def load_graph(fname):
     model_proto = onnx.ModelProto()
+    with open(fname, "rb") as f:
+        data = f.read()
+    model_proto.ParseFromString(data)
     g = GraphUtil.create_graph_from_onnx_model(model_proto)
     return g, model_proto.producer_name
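
The fix matters because onnx.ModelProto() starts out empty: without the added read-and-parse step, load_graph built its graph from a blank model. For reference, onnx's own loader does the same in one call (sketch, assuming a local model.onnx):

    import onnx

    # Equivalent to the open/read/ParseFromString sequence above.
    model_proto = onnx.load("model.onnx")
    print(model_proto.producer_name)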
