|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | + |
| 3 | + |
| 4 | +""" |
| 5 | +tf2onnx.rewriter.lstm_tf2_rewriter - Rewrites LSTM pattern used by tf2. |
| 6 | +""" |
| 7 | + |
| 8 | +import numpy as np |
| 9 | +from tf2onnx.graph_matcher import GraphMatcher |
| 10 | +from tf2onnx.rewriter.rnn_utils import make_lstmcell_pattern |
| 11 | +from tf2onnx.tf_loader import find_function |
| 12 | +from tf2onnx.rewriter.lstm_rewriter_base import LSTMContext |
| 13 | +from tf2onnx.rewriter.lstm_rewriter import LSTMRewriter |
| 14 | +from tf2onnx.graph_builder import GraphBuilder |
| 15 | +from tf2onnx import utils |
| 16 | + |
| 17 | +# pylint: disable=invalid-name,unused-argument,missing-docstring, unused-variable |
| 18 | + |
| 19 | + |
def _extract_lstm_body_context(g, match_result):
    """Derive the LSTM rewriter context from one LSTM-cell pattern match.

    Every tensor the matched cell depends on (input tensor list, state, kernel,
    bias, output tensor list and, optionally, sequence lengths) has to be a
    graph input of ``g``; each one is recorded as its position in
    ``g.input_names``.

    Args:
        g: graph the pattern was matched in.
        match_result: one match produced by ``GraphMatcher.match_ops``.

    Returns:
        A dict with the keys ``x_idx``, ``out_idx``, ``weight_idx``,
        ``bias_idx``, ``ft_bias``, ``seq_len_idx`` (``None`` unless the cell
        output goes through exactly one ``Select`` feeding a
        ``TensorListSetItem``), ``state_is_tuple`` and either ``c_idx`` and
        ``h_idx`` (separate states) or ``ch_idx`` (concatenated state).
        ``None`` when the match does not have the expected structure.
    """
    def has_tensor_list_consumer(n):
        return any(c.type == "TensorListSetItem" for c in g.find_output_consumers(n.output[0]))

    # The "xh" concat must have exactly three inputs and its first input has to
    # be read from a tensor list that is a graph input.
    concat = match_result.get_op("xh")
    if len(concat.inputs) != 3:
        return None
    get_item = concat.inputs[0]
    if get_item.type != "TensorListGetItem":
        return None
    x_e = get_item.inputs[0]
    if not x_e.is_graph_input():
        return None
    x_idx = g.input_names.index(x_e.output[0])

    # Follow the cell output "ht" to the TensorListSetItem that stores it,
    # optionally passing through a single Select(GreaterEqual(_, seq_len), ...).
    ht_mul = match_result.get_op("ht")
    final_consumers = g.find_output_consumers(ht_mul.output[0])
    select_ops = [n for n in final_consumers if n.type == "Select" and has_tensor_list_consumer(n)]
    if len(select_ops) == 1:
        greater_eq = select_ops[0].inputs[0]
        if greater_eq.type != "GreaterEqual":
            return None
        seq_len = greater_eq.inputs[1]
        if not seq_len.is_graph_input():
            return None
        seq_len_idx = g.input_names.index(seq_len.output[0])
        final_consumers = g.find_output_consumers(select_ops[0].output[0])
    else:
        seq_len_idx = None

    tensor_set_items = [n for n in final_consumers if n.type == "TensorListSetItem"]
    if len(tensor_set_items) != 1:
        return None
    if not tensor_set_items[0].inputs[0].is_graph_input():
        return None
    out_idx = g.input_names.index(tensor_set_items[0].input[0])

    if concat.inputs[1].is_graph_input():
        # c and h are separate
        h_idx = g.input_names.index(concat.input[1])
        c_e = match_result.get_op("c")
        if not c_e.is_graph_input():
            return None
        c_idx = g.input_names.index(c_e.output[0])
        ch_info = {
            "state_is_tuple": True,
            "c_idx": c_idx,
            "h_idx": h_idx,
        }
    else:
        # c and h are concatenated: both must be Slice ops of the same graph input
        if concat.inputs[1].type != "Slice":
            return None
        ch_e = concat.inputs[1].inputs[0]
        if not ch_e.is_graph_input():
            return None
        ch_idx = g.input_names.index(ch_e.output[0])

        c_e = match_result.get_op("c")
        if c_e.type != "Slice" or c_e.input[0] != ch_e.output[0]:
            return None
        ch_info = {
            "state_is_tuple": False,
            "ch_idx": ch_idx,
        }

    w_e = match_result.get_op("cell_kernel")
    if not w_e.is_graph_input():
        return None
    w_idx = g.input_names.index(w_e.output[0])

    # Only the "NHWC" data format is accepted when a bias_add op was matched.
    bias_add = match_result.get_op("bias_add")
    if bias_add is not None and bias_add.data_format != "NHWC":
        return None

    b_e = match_result.get_op("cell_bias")
    if not b_e.is_graph_input():
        return None
    b_idx = g.input_names.index(b_e.output[0])

    # ft_bias must be a constant with the same dtype as the cell bias.
    ft_bias_node = match_result.get_op("ft_bias")
    if not ft_bias_node.is_const():
        return None
    if g.get_dtype(ft_bias_node.output[0]) != g.get_dtype(b_e.output[0]):
        return None
    ft_bias = ft_bias_node.get_tensor_value(as_list=False)

    return {
        "x_idx": x_idx,
        "out_idx": out_idx,
        "weight_idx": w_idx,
        "bias_idx": b_idx,
        "ft_bias": ft_bias,
        "seq_len_idx": seq_len_idx,
        **ch_info
    }


def _rewrite_while_as_lstm(g, op, body_context):
    """Rewire the consumers of While node ``op`` to a node built by LSTMRewriter.

    All preconditions are checked before any node is added to ``g`` so that a
    rejected loop leaves the graph untouched.

    Args:
        g: graph that contains ``op``.
        op: the While node whose body carries ``body_context``.
        body_context: dict recorded on the body graph (see
            ``_extract_lstm_body_context``); its ``*_idx`` entries are used to
            index ``op.input`` / ``op.inputs`` / ``op.output``.

    Returns:
        True when the loop was rewritten, False when it was skipped.
    """
    w = op.input[body_context["weight_idx"]]
    b = op.input[body_context["bias_idx"]]
    if not g.is_const(w) or not g.is_const(b):
        return False
    w_const = g.get_tensor_value(w, as_list=False)
    b_const = g.get_tensor_value(b, as_list=False)

    # Checked up front: doing this after the state nodes are created would
    # leave unused Unsqueeze/const nodes behind whenever the loop is rejected.
    tensor_array_inp = op.inputs[body_context["x_idx"]]
    if tensor_array_inp.type != "TensorListFromTensor":
        return False

    if body_context["state_is_tuple"]:
        initial_c_sq = op.input[body_context["c_idx"]]
        initial_h_sq = op.input[body_context["h_idx"]]
        # presumably adds the leading num_directions axis expected by the LSTM op - TODO confirm
        initial_c = GraphBuilder(g).make_unsqueeze({"data": initial_c_sq, "axes": [0]})
        initial_h = GraphBuilder(g).make_unsqueeze({"data": initial_h_sq, "axes": [0]})
    else:
        initial_ch = op.input[body_context["ch_idx"]]
        if not g.is_const(initial_ch):
            return False
        initial_ch_const = g.get_tensor_value(initial_ch, as_list=False)
        if len(initial_ch_const.shape) != 2:
            return False
        if initial_ch_const.shape[1] % 2 != 0:
            # np.split below needs an even width to cut the state into c and h halves
            return False
        initial_ch_const = np.expand_dims(initial_ch_const, axis=0)
        initial_c_const, initial_h_const = np.split(initial_ch_const, 2, axis=2)
        initial_c = g.make_const(utils.make_name("initial_c"), initial_c_const).output[0]
        initial_h = g.make_const(utils.make_name("initial_h"), initial_h_const).output[0]

    if body_context["seq_len_idx"] is None:
        sequence_lens = ""
    else:
        sequence_lens = op.input[body_context["seq_len_idx"]]

    context = LSTMContext()
    context.weights.append({"weight": w_const, "bias": b_const, "ft_bias": body_context["ft_bias"]})
    context.onnx_input_ids.append({
        "X": tensor_array_inp.input[0],
        "sequence_lens": sequence_lens,
        "initial_c": initial_c,
        "initial_h": initial_h,
    })
    context.input_size.append(None)
    context.hidden_size.append(None)
    context.attributes.append({})

    final_consumers = g.find_output_consumers(op.output[body_context["out_idx"]])
    output_ys = [n.output[0] for n in final_consumers if n.type == "TensorListStack"]

    lstm_rewriter = LSTMRewriter(g)
    lstm_rewriter.num_lstm_layers = 1
    lstm_rewriter.process_weights_and_bias(context)
    lstm_node = lstm_rewriter.create_rnn_node(context)[0]

    # Every stacked loop output is replaced by output 0 of the new node with axis 1 squeezed away.
    squeeze_output = GraphBuilder(g).make_squeeze({"data": lstm_node.output[0], "axes": [1]})
    for output in output_ys:
        g.replace_all_inputs(output, squeeze_output)

    # Output 1 of the new node stands in for the final h state, output 2 for the final c state.
    if body_context["state_is_tuple"]:
        c_squeeze = GraphBuilder(g).make_squeeze({"data": lstm_node.output[2], "axes": [0]})
        h_squeeze = GraphBuilder(g).make_squeeze({"data": lstm_node.output[1], "axes": [0]})
        g.replace_all_inputs(op.output[body_context["c_idx"]], c_squeeze)
        g.replace_all_inputs(op.output[body_context["h_idx"]], h_squeeze)
    else:
        # Rebuild the concatenated state in the same (c, h) order it was split in.
        concat_ch = g.make_node("Concat", [lstm_node.output[2], lstm_node.output[1]],
                                attr={"axis": 2}).output[0]
        ch_squeeze = GraphBuilder(g).make_squeeze({"data": concat_ch, "axes": [0]})
        ch_output = op.output[body_context["ch_idx"]]
        g.replace_all_inputs(ch_output, ch_squeeze)
    return True


def rewriter_lstm_tf2(g, ops):
    """Rewrite the LSTM pattern used by tf2.

    Works in two phases on the same call:

    1. The LSTM-cell pattern is matched against ``ops``. For each match with
       the expected structure, the positions of the involved graph inputs are
       stored on ``g.lstm_rewriter_context`` (a later match overwrites an
       earlier one).
    2. For every While node in ``ops`` whose body graph (looked up with
       ``find_function``) carries such a context, the loop's consumers are
       rewired to a node created by ``LSTMRewriter``.

    Args:
        g: graph being rewritten.
        ops: nodes of ``g`` to inspect.

    Returns:
        ``g.get_nodes()`` after rewriting.
    """
    pattern1 = make_lstmcell_pattern("Identity")

    for pattern in [pattern1]:
        matcher = GraphMatcher(pattern, allow_reorder=False)
        match_results = list(matcher.match_ops(ops))
        for match_result in match_results:
            body_context = _extract_lstm_body_context(g, match_result)
            if body_context is not None:
                g.lstm_rewriter_context = body_context

    for op in ops:
        if not op.is_while():
            continue
        body_graph = find_function(op.get_attr_str("body"))
        # NOTE(review): getattr also covers a body that could not be resolved
        # (None) or that lacks the attribute; such loops are skipped, not crashed on.
        body_context = getattr(body_graph, "lstm_rewriter_context", None)
        if body_context is None:
            continue
        _rewrite_while_as_lstm(g, op, body_context)

    return g.get_nodes()
0 commit comments