Skip to content

Commit e0fa7bd

Browse files
add unittest
1 parent c7ef119 commit e0fa7bd

File tree

3 files changed

+18
-1
lines changed

3 files changed

+18
-1
lines changed

tests/backend_test_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def run_test_case(self, feed_dict, input_names_with_port, output_names_with_port
130130
graph_def = tf_optimize(input_names_with_port, output_names_with_port,
131131
sess.graph_def, constant_fold)
132132

133-
if self.config.is_debug_mode and constant_fold:
133+
if self.config.is_debug_mode:
134134
model_path = os.path.join(self.test_data_directory, self._testMethodName + "_after_tf_optimize.pb")
135135
utils.save_protobuf(model_path, graph_def)
136136
self.log.debug("created file %s", model_path)

tests/run_pretrained_models.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,8 @@ def run_test(self, name, backend="caffe2", debug=False, onnx_file=None, opset=No
242242
inputs[k] = v
243243

244244
graph_def = tf2onnx.tfonnx.tf_optimize(inputs.keys(), self.output_names, graph_def, fold_const)
245+
if debug:
246+
utils.save_protobuf(os.path.join(TEMP_DIR, name + "_after_tf_optimize.pb"), graph_def)
245247
shape_override = {}
246248
g = tf.import_graph_def(graph_def, name='')
247249
with tf.Session(config=tf.ConfigProto(allow_soft_placement=True), graph=g) as sess:

tests/test_backend.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,21 @@ def test_dropout(self):
353353
output_names_with_port = ["output:0"]
354354
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port)
355355

356+
def test_nn_dropout(self):
357+
keep_prob = tf.placeholder_with_default(1., (), "keep_prob")
358+
x_val = np.ones([1, 24, 24, 3], dtype=np.float32)
359+
# Define a scope for reusing the variables
360+
x = tf.placeholder(tf.float32, shape=x_val.shape, name="input_1")
361+
x_ = tf.identity(x)
362+
363+
fc1 = tf.nn.dropout(x_, keep_prob)
364+
365+
_ = tf.identity(fc1, name="output")
366+
feed_dict = {"input_1:0": x_val}
367+
input_names_with_port = ["input_1:0"]
368+
output_names_with_port = ["output:0"]
369+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, constant_fold=False)
370+
356371
def test_conv2d_with_input_transpose(self):
357372
x_shape = [2, 32, 32, 3]
358373
kernel_shape = [3, 3, 3, 3]

0 commit comments

Comments
 (0)