Skip to content

Commit aab7878

Browse files
resolve comments
1 parent a052dec commit aab7878

File tree

10 files changed

+168
-131
lines changed

10 files changed

+168
-131
lines changed

tests/test_loops.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,24 +20,6 @@
2020

2121
class LoopTests(Tf2OnnxBackendTestBase):
2222

23-
@check_tf_min_version("1.9")
24-
def test_simple_while_loop_var_shape(self):
25-
# test for while_loop with variant shape variables
26-
# may not meet ONNX Loop spec
27-
i = tf.placeholder(tf.int32, (1), name="input_1")
28-
const = tf.constant(np.array([2], dtype=np.int32))
29-
30-
c = lambda i: tf.reduce_all(tf.shape(i) < 10)
31-
b = lambda i: tf.concat([i, const], 0)
32-
r = tf.while_loop(c, b, [i], shape_invariants=[tf.TensorShape([None])])
33-
34-
_ = tf.identity(r, name="output")
35-
input_names_with_port = ["input_1:0"]
36-
feed_dict = {"input_1:0": np.array([0], dtype=np.int32)}
37-
38-
output_names_with_port = ["output:0"]
39-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06)
40-
4123
def test_simple_while_loop(self):
4224
i = tf.placeholder(tf.int32, (), name="input_1")
4325
c = lambda i: tf.less(i, 10)
@@ -214,6 +196,24 @@ def fn1(elem):
214196
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-5)
215197
tf.reset_default_graph()
216198

199+
@check_tf_min_version("1.9")
200+
def test_simple_while_loop_var_shape(self):
201+
# test for while_loop with variant shape variables
202+
# may not meet ONNX Loop spec
203+
i = tf.placeholder(tf.int32, (1), name="input_1")
204+
const = tf.constant(np.array([2], dtype=np.int32))
205+
206+
c = lambda i: tf.reduce_all(tf.shape(i) < 10)
207+
b = lambda i: tf.concat([i, const], 0)
208+
r = tf.while_loop(c, b, [i], shape_invariants=[tf.TensorShape([None])])
209+
210+
_ = tf.identity(r, name="output")
211+
input_names_with_port = ["input_1:0"]
212+
feed_dict = {"input_1:0": np.array([0], dtype=np.int32)}
213+
214+
output_names_with_port = ["output:0"]
215+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06)
216+
217217

218218
if __name__ == '__main__':
219219
unittest_main()

tests/test_tf_shape_inference.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,9 @@ def _run_test_case(self, input_names_with_port, output_names_with_port):
6868
def _compare_shape_for_op(self, op1, op2):
6969
"""Align outputs of op2 to op1."""
7070
for out1, out2 in zip(op1.outputs, op2.outputs):
71-
expected_shape = utils.get_shape_from_tf_output(out1)
71+
expected_shape = utils.get_tf_tensor_shape(out1)
7272
if out1 is not None:
73-
actual_shape = utils.get_shape_from_tf_output(out2)
73+
actual_shape = utils.get_tf_tensor_shape(out2)
7474
self.assertTrue(utils.are_shapes_compatible(expected_shape, actual_shape))
7575

7676
def test_while_loop_with_ta_read_and_write(self):

tf2onnx/rewriter/bigru_rewriter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import logging
1414
import numpy as np
1515
from tf2onnx import utils
16-
from tf2onnx.utils import is_reverse_op
16+
from tf2onnx.utils import is_tf_reverse_op
1717
from tf2onnx.rewriter.bilstm_rewriter import slice_bilstm_for_original_lstm_consumers,\
1818
get_reverse_nodes_after_y_output, get_np_val_for_const, _process_single_init_node
1919

@@ -143,7 +143,7 @@ def rewrite_bidirectional_grus(g, ops):
143143
input_id = temp.input[0]
144144
temp = temp.inputs[0]
145145

146-
if is_reverse_op(temp):
146+
if is_tf_reverse_op(temp):
147147
input_id = temp.input[0]
148148
is_backward_gru = True
149149

tf2onnx/rewriter/bilstm_rewriter.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import logging
1414
import numpy as np
1515
from tf2onnx import utils
16-
from tf2onnx.utils import is_reverse_op
16+
from tf2onnx.utils import is_tf_reverse_op
1717
from tf2onnx.graph_builder import GraphBuilder
1818

1919
logger = logging.getLogger(__name__)
@@ -186,7 +186,7 @@ def rewrite_bidirectional_lstms(g, ops):
186186
input_id = temp.input[0]
187187
temp = temp.inputs[0]
188188

189-
if is_reverse_op(temp):
189+
if is_tf_reverse_op(temp):
190190
input_id = temp.input[0]
191191
is_backward_lstm = True
192192

@@ -220,13 +220,13 @@ def get_reverse_nodes_after_y_output(g, lstm_bw):
220220
if len(trans_nodes) == 1:
221221
if trans_nodes[0].type == "Transpose":
222222
reverse_nodes = g.find_output_consumers(trans_nodes[0].output[0])
223-
elif is_reverse_op(trans_nodes[0]):
223+
elif is_tf_reverse_op(trans_nodes[0]):
224224
reverse_nodes = trans_nodes
225225
else:
226226
logger.debug("not found reverse op, unexpected")
227227
return None
228228

229-
are_all_reverse = all([is_reverse_op(r_op) for r_op in reverse_nodes])
229+
are_all_reverse = all([is_tf_reverse_op(r_op) for r_op in reverse_nodes])
230230
if are_all_reverse:
231231
return reverse_nodes
232232

tf2onnx/rewriter/loop_rewriter_base.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
from collections import OrderedDict
1313
from tf2onnx import utils
1414
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
15-
from tf2onnx.utils import is_loopcond_op, is_tensor_array_op
16-
from tf2onnx.utils import is_tensor_array_gather_op, is_tensor_array_write_op
15+
from tf2onnx.utils import is_tf_loopcond_op, is_tf_tensor_array_op
16+
from tf2onnx.utils import is_tf_tensor_array_gather_op, is_tf_tensor_array_write_op
1717
from tf2onnx.rewriter.rnn_utils import REWRITER_RESULT
1818
from tf2onnx.utils import TensorValueInfo
1919

@@ -195,7 +195,7 @@ def rewrite(self, context):
195195
def run_internal(self):
196196
loopcond_ops = []
197197
for op in self.g.get_nodes():
198-
if is_loopcond_op(op):
198+
if is_tf_loopcond_op(op):
199199
loopcond_ops.append(op)
200200

201201
# self.g.get_nodes may change inside this loop so that we parse all LoopCond first
@@ -316,7 +316,7 @@ def _cut_off_connection_for_cell(self, context):
316316

317317
if val.is_tensor_array:
318318
# connect NextIteration to an invalid node, to cut off an ending node of the cell.
319-
ta_write_nodes = [n for n in self.g.get_nodes() if is_tensor_array_write_op(n)]
319+
ta_write_nodes = [n for n in self.g.get_nodes() if is_tf_tensor_array_write_op(n)]
320320
self.g.replace_all_inputs(ta_write_nodes, val.next_iteration_input.id, INVALID_INPUT_ID)
321321
else:
322322
# connect NextIteration to an invalid node, to cut off an ending node of the cell.
@@ -376,19 +376,19 @@ def _get_loop_var_from_switch(self, switch_node):
376376

377377
is_ta = False
378378
ta_index_id = None
379-
if is_tensor_array_op(self.g.get_node_by_output(target_node_input_id)):
379+
if is_tf_tensor_array_op(self.g.get_node_by_output(target_node_input_id)):
380380
is_ta = True
381381

382382
ta_write_node = self.g.get_node_by_output(last_iteration_output_id)
383-
utils.make_sure(is_tensor_array_write_op(ta_write_node), "ta nextiteration is not following ta write op")
383+
utils.make_sure(is_tf_tensor_array_write_op(ta_write_node), "ta nextiteration is not following ta write op")
384384
last_iteration_output_id = ta_write_node.input[2]
385385
ta_index_id = ta_write_node.input[1]
386386

387387
# here we parse patterns generated by
388388
# ta.write(), then ta.stack(), because this is the most frequent usage pattern.
389389
if exit_output_id:
390390
exit_consumers = self.g.find_output_consumers(exit_output_id)
391-
ta_gather_node = [n for n in exit_consumers if is_tensor_array_gather_op(n)][0]
391+
ta_gather_node = [n for n in exit_consumers if is_tf_tensor_array_gather_op(n)][0]
392392

393393
# update exit output id, treat the gather output as ta's output
394394
exit_output_id = ta_gather_node.output[0]

tf2onnx/rewriter/lstm_rewriter.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from tf2onnx import utils
1515
from tf2onnx.graph_builder import GraphBuilder
1616
from tf2onnx.rewriter.rnn_utils import RNNUnitType, RnnWeight, get_weights_from_const_node
17-
from tf2onnx.utils import is_concat_op, is_slice_op
17+
from tf2onnx.utils import is_tf_concat_op, is_tf_slice_op
1818

1919
from tf2onnx.rewriter.unit_rnn_rewriter_base import UnitRnnRewriterBase
2020

@@ -169,8 +169,8 @@ def _ct_ht_shared_variable_finder(self, context):
169169
lstm_cell = context.cell_match
170170
ct = lstm_cell.get_op("ct").output[0]
171171
ht = lstm_cell.get_op("ht").output[0]
172-
ct_concat = [c for c in self.g.find_output_consumers(ct) if is_concat_op(c)]
173-
ht_concat = [c for c in self.g.find_output_consumers(ht) if is_concat_op(c)]
172+
ct_concat = [c for c in self.g.find_output_consumers(ct) if is_tf_concat_op(c)]
173+
ht_concat = [c for c in self.g.find_output_consumers(ht) if is_tf_concat_op(c)]
174174
if len(ct_concat) != 1 or len(ht_concat) != 1 or ct_concat[0] != ht_concat[0]:
175175
logger.debug("failed to find ct-ht concat")
176176
return None
@@ -179,8 +179,8 @@ def _ct_ht_shared_variable_finder(self, context):
179179
consumers = []
180180
ct_identity_consumer = lstm_cell.get_op("ct_identity_consumer")
181181
ht_identity_consumer = lstm_cell.get_op("xh")
182-
ct_slice = [c for c in ct_identity_consumer.inputs if is_slice_op(c)]
183-
ht_slice = [c for c in ht_identity_consumer.inputs if is_slice_op(c)]
182+
ct_slice = [c for c in ct_identity_consumer.inputs if is_tf_slice_op(c)]
183+
ht_slice = [c for c in ht_identity_consumer.inputs if is_tf_slice_op(c)]
184184
if len(ct_slice) != 1 or len(ht_slice) != 1:
185185
logger.debug("failed to find slice op before identity consumers")
186186
return None

tf2onnx/rewriter/unit_rnn_rewriter_base.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from tf2onnx.rewriter.loop_rewriter_base import LoopRewriterBase, Context
1515
from tf2onnx.rewriter.rnn_utils import REWRITER_RESULT, get_pattern, \
1616
get_rnn_scope_name, parse_rnn_loop, seq_len_pattern
17-
from tf2onnx.utils import is_select_op, is_tensor_array_write_op
17+
from tf2onnx.utils import is_tf_select_op, is_tf_tensor_array_write_op
1818
from tf2onnx.graph_matcher import GraphMatcher
1919

2020

@@ -210,7 +210,7 @@ def find_sequence_length_node(self, context):
210210
# get any state variable
211211
state_variable = list(context.state_variables.values())[0]
212212
next_iter_input_node = self.g.get_node_by_output(state_variable.next_iteration_input.id)
213-
if not is_select_op(next_iter_input_node):
213+
if not is_tf_select_op(next_iter_input_node):
214214
logger.debug("no sequence length node is given")
215215
return None
216216
matcher = GraphMatcher(seq_len_pattern)
@@ -302,10 +302,10 @@ def _find_state_variable_with_select(self, context,
302302
# find all select not followed by TensorArrayWrite
303303
select = []
304304
for c in self.g.find_output_consumers(next_iteration_input):
305-
if not is_select_op(c):
305+
if not is_tf_select_op(c):
306306
continue
307307
out_ta_writer = [
308-
o for o in self.g.find_output_consumers(c.output[0]) if is_tensor_array_write_op(o)
308+
o for o in self.g.find_output_consumers(c.output[0]) if is_tf_tensor_array_write_op(o)
309309
]
310310
if out_ta_writer:
311311
continue

0 commit comments

Comments
 (0)