Skip to content

Commit d7ee792

Browse files
Misc fixes for adding keras2onnx tests (#1567)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 457a487 commit d7ee792

File tree

4 files changed

+12
-7
lines changed

4 files changed

+12
-7
lines changed

tf2onnx/onnx_opset/reduction.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,9 @@ def version_6(cls, ctx, node, **kwargs):
120120
reduce_dim = [reduce_dim]
121121

122122
if ctx.opset < 11:
123+
inp_rank = ctx.get_rank(node.input[0])
124+
if inp_rank is not None:
125+
reduce_dim = [d + inp_rank if d < 0 else d for d in reduce_dim]
123126
utils.make_sure(all(i >= 0 for i in reduce_dim), "negative reduce axis is not supported in onnx for now")
124127

125128
cast = ctx.make_node(op_type="Cast", inputs=[node.input[0]], attr={"to": onnx_pb.TensorProto.FLOAT})

tf2onnx/onnx_opset/tensor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -856,7 +856,8 @@ def version_1(cls, ctx, node, **kwargs):
856856
node = GraphBuilder(ctx).make_slice(
857857
kwargs, name=node.name, dtypes=out_dtypes, shapes=out_shapes, return_node=True)
858858
else:
859-
node = ctx.make_node("Identity", [node.input[0]], name=node.name, dtypes=out_dtypes, shapes=out_shapes)
859+
node = ctx.make_node("Identity", [node.input[0]], name=node.name, outputs=node.output,
860+
dtypes=out_dtypes, shapes=out_shapes)
860861
nodes = [node]
861862
if needs_squeeze:
862863
# insert_new_node_on_output(self, op_type, output_name=None, name=None, inputs=None, domain=None, **kwargs)

tf2onnx/rewriter/dropout_rewriter.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,10 @@ def rewrite_dropout(g, ops):
6262
ratio = input3.get_tensor_value()
6363

6464
if input2.inputs[0].is_scalar():
65-
data = input2.inputs[1]
65+
data_output = input2.input[1]
6666
scaling_constant = input2.inputs[0].get_tensor_value()
6767
elif input2.inputs[1].is_scalar():
68-
data = input2.inputs[0]
68+
data_output = input2.input[0]
6969
scaling_constant = input2.inputs[1].get_tensor_value()
7070
else:
7171
logger.warning("Could not find scaling constant for dropout pattern rooted at %s. "
@@ -89,12 +89,12 @@ def rewrite_dropout(g, ops):
8989
out_name = utils.port_name(op_name)
9090
new_node = g.make_node(
9191
"Dropout",
92-
inputs=[data.output[0]],
92+
inputs=[data_output],
9393
outputs=[out_name],
9494
name=op_name,
9595
attr={"ratio": ratio},
96-
shapes=[g.get_shape(data.output[0])],
97-
dtypes=[g.get_dtype(data.output[0])]
96+
shapes=[g.get_shape(data_output)],
97+
dtypes=[g.get_dtype(data_output)]
9898
)
9999
g.replace_all_inputs(outputs.output[0], new_node.output[0], ops=ops)
100100
for n in nodes_to_remove:

tf2onnx/rewriter/random_uniform.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def create_onnx_random_uniform_op(g, tmax, tmin, ru_op, output, to_delete):
7272
dtype = g.get_dtype(output.output[0])
7373
op_name = utils.make_name("RandomUniform")
7474
shape_node = ru_op.inputs[0]
75+
shape_node_output = ru_op.input[0]
7576
shape = g.get_shape(output.output[0])
7677
if shape_node.is_const():
7778
# if the tensorflow input (aka the shape) is const we can use the RandomUniform op
@@ -103,7 +104,7 @@ def create_onnx_random_uniform_op(g, tmax, tmin, ru_op, output, to_delete):
103104
to_delete.remove(shape_node)
104105
# create a fill op with the shape of the value of the input tensor
105106
zero = g.make_const(utils.make_name("zero"), np.zeros((), dtype=np.float32))
106-
fill_node = g.make_node("Fill", inputs=[shape_node.output[0], zero.name],
107+
fill_node = g.make_node("Fill", inputs=[shape_node_output, zero.name],
107108
shapes=[shape], dtypes=[dtype])
108109
func, _ = handler.tf_op.find_effective_op("Fill")
109110
func(g, fill_node)

0 commit comments

Comments
 (0)