Skip to content

Commit 2263e91

Browse files
committed
refactor
move is_xx_op from rnn_utils to utils
1 parent a2c46bd commit 2263e91

File tree

8 files changed

+41
-43
lines changed

8 files changed

+41
-43
lines changed

tf2onnx/rewriter/bigru_rewriter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import logging
1414
import numpy as np
1515
from tf2onnx import utils
16-
from tf2onnx.rewriter.rnn_utils import is_reverse_op
16+
from tf2onnx.utils import is_reverse_op
1717
from tf2onnx.rewriter.bilstm_rewriter import slice_bilstm_for_original_lstm_consumers,\
1818
get_reverse_nodes_after_y_output, get_np_val_for_const, _process_single_init_node
1919

tf2onnx/rewriter/bilstm_rewriter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import logging
1414
import numpy as np
1515
from tf2onnx import utils
16-
from tf2onnx.rewriter.rnn_utils import is_reverse_op
16+
from tf2onnx.utils import is_reverse_op
1717
from tf2onnx.graph_builder import GraphBuilder
1818

1919
logger = logging.getLogger(__name__)

tf2onnx/rewriter/loop_rewriter_base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
from collections import OrderedDict
1313
from tf2onnx import utils
1414
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
15-
from tf2onnx.rewriter.rnn_utils import is_loopcond_op, is_tensor_array_op
16-
from tf2onnx.rewriter.rnn_utils import is_tensor_array_gather_op, is_tensor_array_write_op
15+
from tf2onnx.utils import is_loopcond_op, is_tensor_array_op
16+
from tf2onnx.utils import is_tensor_array_gather_op, is_tensor_array_write_op
1717
from tf2onnx.rewriter.rnn_utils import REWRITER_RESULT
1818
from tf2onnx.utils import TensorValueInfo
1919

tf2onnx/rewriter/lstm_rewriter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
import numpy as np
1414
from tf2onnx import utils
1515
from tf2onnx.graph_builder import GraphBuilder
16-
from tf2onnx.rewriter.rnn_utils import RNNUnitType, RnnWeight, \
17-
is_concat_op, is_slice_op, get_weights_from_const_node
16+
from tf2onnx.rewriter.rnn_utils import RNNUnitType, RnnWeight, get_weights_from_const_node
17+
from tf2onnx.utils import is_concat_op, is_slice_op
1818

1919
from tf2onnx.rewriter.unit_rnn_rewriter_base import UnitRnnRewriterBase
2020

tf2onnx/rewriter/rnn_utils.py

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -263,35 +263,3 @@ def get_weights_from_const_node(g, node):
263263
return None
264264

265265
return RnnWeight(node, val, dtype)
266-
267-
268-
def is_reverse_op(op):
269-
return op.type in ("ReverseV2", "ReverseSequence")
270-
271-
272-
def is_concat_op(op):
273-
return op.type in ("Concat", "ConcatV2", "ConcatV3")
274-
275-
276-
def is_tensor_array_gather_op(op):
277-
return op.type in ("TensorArrayGatherV2", "TensorArrayGatherV3")
278-
279-
280-
def is_tensor_array_write_op(op):
281-
return op.type in ("TensorArrayWriteV2", "TensorArrayWriteV3")
282-
283-
284-
def is_tensor_array_op(op):
285-
return op.type in ("TensorArrayV2", "TensorArrayV3")
286-
287-
288-
def is_loopcond_op(op):
289-
return op.type == "LoopCond"
290-
291-
292-
def is_select_op(op):
293-
return op.type == "Select"
294-
295-
296-
def is_slice_op(op):
297-
return op.type == "Slice"

tf2onnx/rewriter/unit_rnn_rewriter_base.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,11 @@
1313
from tf2onnx.graph_builder import GraphBuilder
1414
from tf2onnx.rewriter.loop_rewriter_base import LoopRewriterBase, Context
1515
from tf2onnx.rewriter.rnn_utils import REWRITER_RESULT, get_pattern, \
16-
get_rnn_scope_name, parse_rnn_loop, is_select_op, is_tensor_array_write_op, \
17-
seq_len_pattern
16+
get_rnn_scope_name, parse_rnn_loop, seq_len_pattern
17+
from tf2onnx.utils import is_select_op, is_tensor_array_write_op
1818
from tf2onnx.graph_matcher import GraphMatcher
1919

2020

21-
2221
logger = logging.getLogger(__name__)
2322

2423

tf2onnx/shape_inference.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import numpy as np
1313
from onnx import onnx_pb
1414
from tf2onnx import utils
15-
from tf2onnx.rewriter import rnn_utils
1615

1716
# pylint: disable=logging-not-lazy,missing-docstring,consider-swap-variables
1817

@@ -215,7 +214,7 @@ def infer_input_shapes(g, node):
215214
def infer_output_shapes_with_partial_inputs(g, node):
216215
# output shape of concat op: only the dim val of concatenated dim will be changed
217216
# so only partial(at least one) input shapes need to be known to infer output shape of concat node
218-
if rnn_utils.is_concat_op(node):
217+
if utils.is_concat_op(node):
219218
data_inputs = node.input[:-1]
220219
input_shapes = [g.get_shape(node) for node in data_inputs]
221220
input_shapes = [shape for shape in input_shapes if shape is not None]

tf2onnx/utils.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -493,3 +493,35 @@ def get_url(url, path, max_retries=5):
493493

494494
with open(path, "wb") as f:
495495
f.write(response.content)
496+
497+
498+
def is_reverse_op(op):
499+
return op.type in ("ReverseV2", "ReverseSequence")
500+
501+
502+
def is_concat_op(op):
503+
return op.type in ("Concat", "ConcatV2", "ConcatV3")
504+
505+
506+
def is_tensor_array_gather_op(op):
507+
return op.type in ("TensorArrayGatherV2", "TensorArrayGatherV3")
508+
509+
510+
def is_tensor_array_write_op(op):
511+
return op.type in ("TensorArrayWriteV2", "TensorArrayWriteV3")
512+
513+
514+
def is_tensor_array_op(op):
515+
return op.type in ("TensorArrayV2", "TensorArrayV3")
516+
517+
518+
def is_loopcond_op(op):
519+
return op.type == "LoopCond"
520+
521+
522+
def is_select_op(op):
523+
return op.type == "Select"
524+
525+
526+
def is_slice_op(op):
527+
return op.type == "Slice"

0 commit comments

Comments
 (0)