Skip to content

Commit 607d780

Browse files
authored
Merge pull request #318 from lucienwang1009/dropout
set shape and dtype for RandomNormal and Dropout ops
2 parents 0fbb359 + cce9bd0 commit 607d780

File tree

3 files changed

+28
-3
lines changed

3 files changed

+28
-3
lines changed

tests/test_backend.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,21 @@ def test_depthwiseconv_1(self):
376376
# rtol is a bit high, 2 values have a bit high error. Maybe use different input data.
377377
self._run_test_case([_OUTPUT], {_INPUT: x_val}, rtol=0.08)
378378

379+
def test_dropout(self):
380+
is_training = tf.placeholder_with_default(False, (), "is_training")
381+
x_val = np.ones([1, 24, 24, 3], dtype=np.float32)
382+
# Define a scope for reusing the variables
383+
x = tf.placeholder(tf.float32, shape=x_val.shape, name="input_1")
384+
x_ = tf.identity(x)
385+
386+
fc1 = tf.layers.dropout(x_, rate=.1, training=is_training)
387+
388+
_ = tf.identity(fc1, name="output")
389+
feed_dict = {"input_1:0": x_val}
390+
input_names_with_port = ["input_1:0"]
391+
output_names_with_port = ["output:0"]
392+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port)
393+
379394
def test_conv2d_with_input_transpose(self):
380395
x_shape = [2, 32, 32, 3]
381396
kernel_shape = [3, 3, 3, 3]

tf2onnx/rewriter/random_uniform.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,11 @@ def create_onnx_random_uniform_op(g, tmax, tmin, ru_op, output):
6969
if ru_op.inputs[0].type == "Shape":
7070
shape_node = ru_op.inputs[0]
7171
new_node = g.make_node("RandomUniformLike", inputs=[shape_node.input[0]], name=op_name,
72-
attr={"low": tmin, "high": tmax, "dtype": dtype})
72+
attr={"low": tmin, "high": tmax, "dtype": dtype},
73+
shapes=shape_node.output_shapes, dtypes=[dtype])
7374
else:
7475
shape = g.get_shape(output.output[0])
7576
new_node = g.make_node("RandomUniform", [], name=op_name,
76-
attr={"low": tmin, "high": tmax, "dtype": dtype, "shape": shape})
77+
attr={"low": tmin, "high": tmax, "dtype": dtype, "shape": shape},
78+
shapes=[shape], dtypes=[dtype])
7779
return new_node

tf2onnx/tfonnx.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1962,7 +1962,15 @@ def rewrite_dropout(g, ops):
19621962
outputs = match.get_op('outputs')
19631963
op_name = utils.make_name("Dropout")
19641964
out_name = port_name(op_name)
1965-
new_node = g.make_node("Dropout", [inputs2.input[0]], outputs=[out_name], name=op_name, attr={"ratio": 1.0})
1965+
new_node = g.make_node(
1966+
"Dropout",
1967+
[inputs2.input[0]],
1968+
outputs=[out_name],
1969+
name=op_name,
1970+
attr={"ratio": 1.0},
1971+
shapes=[g.get_shape(inputs2.input[0])],
1972+
dtypes=[g.get_dtype(inputs2.input[0])]
1973+
)
19661974
ops = g.replace_subgraph(ops, match, [inputs2], [outputs], [new_node], [new_node])
19671975

19681976
return ops

0 commit comments

Comments
 (0)