|
| 1 | +# Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | +# Licensed under the MIT license. |
| 3 | + |
| 4 | +"""Unit Tests for Tensorflow shape inference.""" |
| 5 | + |
| 6 | +from __future__ import division |
| 7 | +from __future__ import print_function |
| 8 | +from __future__ import unicode_literals |
| 9 | + |
| 10 | +import os |
| 11 | +import numpy as np |
| 12 | +import tensorflow as tf |
| 13 | + |
| 14 | +from tensorflow.python.ops import variables as variables_lib |
| 15 | +from tensorflow.python.ops import init_ops |
| 16 | + |
| 17 | +from backend_test_base import Tf2OnnxBackendTestBase |
| 18 | +from common import * # pylint: disable=wildcard-import, unused-wildcard-import |
| 19 | +from tf2onnx import utils |
| 20 | +from tf2onnx.tfonnx import tf_optimize |
| 21 | +from tf2onnx.shape_inference import infer_shape_for_graph |
| 22 | + |
| 23 | +# pylint: disable=missing-docstring |
| 24 | + |
| 25 | + |
class TFShapeInferenceTests(Tf2OnnxBackendTestBase):
    """Compare tf2onnx shape inference with the shapes recorded in the original TF graph.

    Every ``test_*`` method builds a model in the default TensorFlow graph and
    passes the names of its input and output tensors to ``_run_test_case``.
    """

    def _run_test_case(self, input_names_with_port, output_names_with_port):
        """Freeze and optimize the default graph, re-infer its shapes and compare them.

        Args:
            input_names_with_port: names of the graph inputs including the
                output port, e.g. ``"input_1:0"``.
            output_names_with_port: names of the graph outputs including the
                output port, e.g. ``"output_0:0"``.

        Raises:
            AssertionError: if an inferred output shape is not compatible with
                the shape of the same output in the original graph.
        """
        with tf.Session() as sess:
            # Keep the graph exactly as the test built it; its tensor shapes
            # are the expected values for the comparison below.
            origin_graph = sess.graph
            variables_lib.global_variables_initializer().run()
            output_name_without_port = [n.split(':')[0] for n in output_names_with_port]
            # freeze graph
            graph_def = tf.graph_util.convert_variables_to_constants(
                sess, sess.graph_def,
                output_name_without_port
            )

        tf.reset_default_graph()
        tf.import_graph_def(graph_def, name='')

        # Optimize the frozen graph. The session above is closed and still
        # points at the original, un-frozen graph, so its graph_def must not
        # be used here.
        graph_def = tf_optimize(input_names_with_port, output_names_with_port,
                                graph_def, True)

        # Dumping the protobuf needs no session, only the GraphDef itself.
        if self.config.is_debug_mode:
            if not os.path.exists(self.test_data_directory):
                os.makedirs(self.test_data_directory)
            model_path = os.path.join(self.test_data_directory, self._testMethodName + "_after_tf_optimize.pb")
            utils.save_protobuf(model_path, graph_def)
            self.logger.debug("created file %s", model_path)

        tf.reset_default_graph()
        tf.import_graph_def(graph_def, name='')

        with tf.Session() as sess:
            inferred_graph = infer_shape_for_graph(sess.graph)
            # compare each operation
            for op in origin_graph.get_operations():
                try:
                    inferred_op = inferred_graph.get_operation_by_name(op.name)
                except KeyError:
                    # The op does not exist in the processed graph any more,
                    # so there is nothing to compare it with.
                    continue
                self._compare_shape_for_op(op, inferred_op)

    def _compare_shape_for_op(self, op1, op2):
        """Assert that the output shapes of op2 are compatible with those of op1.

        Outputs are paired positionally; op1 supplies the expected shapes and
        op2 the inferred ones.

        Args:
            op1: operation from the original graph.
            op2: operation with the same name from the inferred graph.
        """
        for out1, out2 in zip(op1.outputs, op2.outputs):
            expected_shape = utils.get_tf_tensor_shape(out1)
            # NOTE(review): presumably get_tf_tensor_shape returns None when the
            # shape is unknown; in that case there is nothing to check against.
            if expected_shape is None:
                continue
            actual_shape = utils.get_tf_tensor_shape(out2)
            self.assertTrue(
                utils.are_shapes_compatible(expected_shape, actual_shape),
                "incompatible shape for %s: expected %s, inferred %s" % (out1.name, expected_shape, actual_shape))

    def test_while_loop_with_ta_read_and_write(self):
        """While loop that reads from one TensorArray and writes to another."""
        i = tf.placeholder(tf.int32, (), name="input_1")
        inputs = tf.placeholder(tf.float32, (10,), name="input_2")

        inputs_2 = tf.identity(inputs)
        input_ta = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True).unstack(inputs_2)
        output_ta = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)

        c = lambda i, *_: tf.logical_and(tf.less(i, 10), i >= 0)

        def b(i, out_ta):
            new_i = tf.add(i, 1)
            x = input_ta.read(i)
            x = x + 3
            out_ta_new = out_ta.write(i, x)
            return new_i, out_ta_new

        i_final, out_final = tf.while_loop(c, b, [i, output_ta])
        _ = tf.identity(i_final, name="i")
        _ = tf.identity(out_final.stack(), name="output_ta")

        input_names_with_port = ["input_1:0", "input_2:0"]
        output_names_with_port = ["i:0", "output_ta:0"]
        self._run_test_case(input_names_with_port, output_names_with_port)

    def test_map_fn(self):
        """tf.map_fn over a single tensor and over a tuple of two tensors."""
        def fn0(elem):
            res = elem + elem * elem
            return res

        def fn1(elem):
            res = elem[0] * elem[1] + elem[0]
            return res

        x_val = 100 * np.random.random_sample([2, 10]).astype(np.float32)
        y_val = 100 * np.random.random_sample([2, 10]).astype(np.float32)

        # test fn0
        x = tf.placeholder(tf.float32, shape=x_val.shape, name="input_0")
        x_ = tf.identity(x)
        res_ = tf.map_fn(fn0, x_, dtype=tf.float32)
        _ = tf.identity(res_, name="output_0")
        input_names_with_port = ["input_0:0"]
        output_names_with_port = ["output_0:0"]
        self._run_test_case(input_names_with_port, output_names_with_port)
        # Start from an empty graph so the second case can reuse the same names.
        tf.reset_default_graph()

        # test fn1
        x = tf.placeholder(tf.float32, shape=x_val.shape, name="input_0")
        y = tf.placeholder(tf.float32, shape=y_val.shape, name="input_1")
        x_ = tf.identity(x)
        y_ = tf.identity(y)
        res_ = tf.map_fn(fn1, (x_, y_), dtype=tf.float32)
        _ = tf.identity(res_, name="output_0")
        input_names_with_port = ["input_0:0", "input_1:0"]
        output_names_with_port = ["output_0:0"]
        self._run_test_case(input_names_with_port, output_names_with_port)

    def test_bidrectional_attention_wrapper_lstm_encoder(self):
        """Bidirectional dynamic RNN of attention-wrapped LSTM cells over an LSTM encoder."""
        size = 30
        time_step = 3
        input_size = 4
        attn_size = size
        batch_size = 9

        # shape [batch size, time step, size]
        # attention_state: usually the output of an RNN encoder.
        # This tensor should be shaped `[batch_size, max_time, ...]`
        encoder_time_step = time_step
        encoder_x_val = np.random.randn(encoder_time_step, input_size).astype('f')
        encoder_x_val = np.stack([encoder_x_val] * batch_size)
        encoder_x = tf.placeholder(tf.float32, encoder_x_val.shape, name="input_1")
        encoder_cell = tf.nn.rnn_cell.LSTMCell(size)
        attention_states, _ = tf.nn.dynamic_rnn(encoder_cell, encoder_x, dtype=tf.float32)
        # [9, 3, 30], [9, 30]
        attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(attn_size,
                                                                   attention_states)

        match_input_fn = lambda curr_input, state: tf.concat([curr_input, state], axis=-1)
        cell = tf.nn.rnn_cell.LSTMCell(size)
        match_cell_fw = tf.contrib.seq2seq.AttentionWrapper(cell,
                                                            attention_mechanism,
                                                            attention_layer_size=attn_size,
                                                            cell_input_fn=match_input_fn,
                                                            output_attention=False)
        match_cell_bk = tf.contrib.seq2seq.AttentionWrapper(cell,
                                                            attention_mechanism,
                                                            attention_layer_size=attn_size,
                                                            cell_input_fn=match_input_fn,
                                                            output_attention=False)

        decoder_time_step = 6
        decoder_x_val = np.random.randn(decoder_time_step, batch_size, input_size).astype('f')

        decoder_x = tf.placeholder(tf.float32, decoder_x_val.shape, name="input_2")
        # One sequence length per batch entry; the shape is spelled as a real 1-tuple.
        seq_length = tf.placeholder(tf.int32, (batch_size,), name="input_3")
        (match_output_fw, match_output_bk), (match_state_fw, match_state_bk) = \
            tf.nn.bidirectional_dynamic_rnn(cell_fw=match_cell_fw,
                                            cell_bw=match_cell_bk,
                                            inputs=decoder_x,
                                            sequence_length=tf.identity(seq_length),
                                            dtype=tf.float32,
                                            time_major=True)

        matched_output = tf.concat([match_output_fw, match_output_bk], axis=-1)
        matched_state = tf.concat([match_state_fw.cell_state, match_state_bk.cell_state], -1)

        _ = tf.identity(matched_output, name="output_0")
        _ = tf.identity(matched_state, name="final_state")

        input_names_with_port = ["input_1:0", "input_2:0", "input_3:0"]
        output_names_with_port = ["output_0:0", "final_state:0"]
        self._run_test_case(input_names_with_port, output_names_with_port)

    def test_dynamic_decode_normal_stop(self):
        """seq2seq dynamic_decode with a greedy embedding helper and constant weights."""
        batch_size = 2
        num_units = 4
        vocab_size = 5
        embedding_size = 3
        go_token = 0
        end_token = 1

        embedding = tf.constant(np.ones([vocab_size, embedding_size], dtype=np.float32))
        state_val = np.reshape([np.ones([num_units], dtype=np.float32) * i for i in range(batch_size)],
                               [batch_size, num_units])
        encoder_state = tf.nn.rnn_cell.LSTMStateTuple(state_val, state_val)

        # Fixed weights keep the decoder deterministic from run to run.
        cell_initializer = init_ops.constant_initializer(
            np.array([[-0.9592235, 0.42451382, 0.7437744, -0.54485345, -0.80763197,
                       0.19663906, -0.22738314, 0.7762785, 0.7464578, 0.27227187,
                       0.7661047, 0.3596425, -0.8528242, -0.89316916, -0.48946142,
                       0.87882376],
                      [0.86586094, -0.75018406, 0.25992537, -0.69368935, 0.2515502,
                       -0.26379275, 0.8954313, 0.5759742, -0.7753072, -0.4388857,
                       0.95751476, -0.82085776, -0.9467752, -0.37055635, -0.18570113,
                       -0.86504984],
                      [0.02305841, 0.3850248, 0.893692, -0.6866486, -0.83703446,
                       -0.9828961, 0.3989377, -0.59993076, 0.5330808, 0.6916566,
                       0.98468065, -0.6047034, 0.10823512, 0.34599304, -0.7834821,
                       -0.7852347],
                      [0.81643987, 0.31507468, -0.51369476, -0.12273741, 0.9701307,
                       -0.79669356, -0.34496522, -0.88750815, -0.17995334, 0.34707904,
                       -0.09201193, 0.5363934, -0.87229705, -0.5073328, -0.95894027,
                       0.5481839],
                      [-0.84093595, -0.2341497, -0.86047816, 0.43370056, -0.39073753,
                       0.37730122, 0.48026466, 0.3004985, -0.60727096, 0.9043884,
                       -0.37619448, 0.22490788, -0.03739262, 0.61672115, 0.478899,
                       -0.40780973],
                      [0.31202435, -0.22045255, -0.6087918, 0.95115066, 0.00199413,
                       -0.688287, -0.1103518, 0.4169519, 0.7913246, -0.9844644,
                       -0.6193857, 0.38659644, -0.4726901, -0.44781208, -0.5174744,
                       -0.605911],
                      [0.66771054, 0.34912825, 0.22297978, -0.4990945, 0.24057317,
                       -0.5540829, 0.92277217, 0.74939895, -0.35278273, -0.21587133,
                       -0.28613377, -0.8794241, -0.40119147, 0.67175174, -0.22741508,
                       0.37898326]], dtype=np.float32))
        dense_initializer = init_ops.constant_initializer(
            np.array([[0.56177187, -0.6233454, 0.73997784, 0.35032558, 0.6479795],
                      [0.6831174, -0.34233975, 0.39330363, 0.45177555, -0.49649096],
                      [-0.98890066, 0.6175642, 0.09800482, -0.6721206, 0.48805737],
                      [0.19671416, 0.2623148, 0.742548, 0.13555217, 0.56009054]], dtype=np.float32))

        cell = tf.nn.rnn_cell.LSTMCell(
            num_units=num_units,
            initializer=cell_initializer,
            state_is_tuple=True)

        helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
            embedding=embedding,
            start_tokens=tf.tile([go_token], [batch_size]),
            end_token=end_token)

        output_layer = tf.layers.Dense(vocab_size, kernel_initializer=dense_initializer)
        decoder = tf.contrib.seq2seq.BasicDecoder(
            cell=cell,
            helper=helper,
            initial_state=encoder_state,
            output_layer=output_layer)

        outputs, state, sequence_lengths = tf.contrib.seq2seq.dynamic_decode(
            decoder=decoder,
            maximum_iterations=6)

        _ = tf.identity(outputs.rnn_output, name="rnn_output")
        _ = tf.identity(outputs.sample_id, name="sample_id")
        _ = tf.identity(state, name="state")
        _ = tf.identity(sequence_lengths, name="sequence_lengths")

        output_names_with_port = [
            "rnn_output:0",
            # "sample_id:0", # incomplete type support for Transpose on onnxruntime 0.2.1
            "state:0",
        ]

        # The graph has no placeholders, hence the empty input list.
        self._run_test_case([], output_names_with_port)

    def test_while_loop_in_cond(self):
        """tf.cond whose false branch contains a while loop."""
        x_val = np.array([1, 2, 3], dtype=np.float32)
        y_val = np.array([4, 5, 6], dtype=np.float32)
        x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
        y = tf.placeholder(tf.float32, y_val.shape, name="input_2")

        def cond_graph():
            # while_loop
            # The body increments the loop variable itself, so the condition
            # can eventually become false.
            loop_cond = lambda val: tf.reduce_any(tf.less(val, 10))
            loop_body = lambda val: tf.add(val, 1)
            return tf.while_loop(loop_cond, loop_body, [y])

        res = tf.cond(x[0] < y[0], lambda: x, cond_graph, name="test_cond")
        _ = tf.identity(res, name="output")

        input_names_with_port = ["input_1:0", "input_2:0"]
        output_names_with_port = ["output:0"]
        self._run_test_case(input_names_with_port, output_names_with_port)

    def test_cond_in_while_loop(self):
        """While loop over TensorArrays whose body contains a tf.cond."""
        i = tf.placeholder(tf.int32, (), name="input_1")
        inputs = tf.placeholder(tf.float32, (10,), name="input_2")

        inputs_2 = tf.identity(inputs)
        input_ta = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True).unstack(inputs_2)
        output_ta = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)

        c = lambda i, *_: tf.logical_and(tf.less(i, 10), i >= 0)

        def b(i, out_ta):
            new_i = tf.add(i, 1)
            x = input_ta.read(i)
            x = tf.cond(x > 0, lambda: x - 1, lambda: x + 3)
            out_ta_new = out_ta.write(i, x)
            return new_i, out_ta_new

        i_final, out_final = tf.while_loop(c, b, [i, output_ta])
        _ = tf.identity(i_final, name="i")
        _ = tf.identity(out_final.stack(), name="output_ta")

        input_names_with_port = ["input_1:0", "input_2:0"]
        output_names_with_port = ["i:0", "output_ta:0"]
        self._run_test_case(input_names_with_port, output_names_with_port)
| 316 | + |
| 317 | + |
if __name__ == "__main__":
    # Executed as a script: run this module's tests. `unittest_main` is not
    # defined in this file; presumably it arrives via `from common import *`.
    unittest_main()
0 commit comments