Skip to content

Commit e4c3409

Browse files
committed
fix unit test
1 parent 04d0b80 commit e4c3409

9 files changed

+28
-32
lines changed

tests/test_cudnn_compatible_gru.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,7 @@
1616
# pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test,cell-var-from-loop
1717

1818
if is_tf2():
19-
MultiRNNCell = getattr(tf.compat.v1.nn.rnn_cell, "MultiRNNCell", None)
20-
dynamic_rnn = tf.compat.v1.nn.dynamic_rnn
21-
bidirectional_dynamic_rnn = tf.compat.v1.nn.bidirectional_dynamic_rnn
19+
pass
2220
else:
2321
GRUBlockCell = tf.contrib.rnn.GRUBlockCell
2422
MultiRNNCell = tf.contrib.rnn.MultiRNNCell

tests/test_custom_rnncell.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,7 @@
1616
# pylint: disable=abstract-method,arguments-differ
1717

1818
if is_tf2():
19-
BasicLSTMCell = getattr(tf.compat.v1.nn.rnn_cell, "BasicLSTMCell", None)
20-
LSTMCell = getattr(tf.compat.v1.nn.rnn_cell, "LSTMCell", None)
21-
GRUCell = getattr(tf.compat.v1.nn.rnn_cell, "GRUCell", None)
22-
RNNCell = getattr(tf.compat.v1.nn.rnn_cell, "RNNCell", None)
23-
MultiRNNCell = getattr(tf.compat.v1.nn.rnn_cell, "MultiRNNCell", None)
24-
dynamic_rnn = tf.compat.v1.nn.dynamic_rnn
25-
bidirectional_dynamic_rnn = tf.compat.v1.nn.bidirectional_dynamic_rnn
19+
pass
2620
else:
2721
LSTMBlockCell = tf.contrib.rnn.LSTMBlockCell
2822
LSTMCell = tf.nn.rnn_cell.LSTMCell

tests/test_gru.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,13 @@
3434

3535
if is_tf2():
3636
# There is no LSTMBlockCell in tf-2.x
37-
BasicLSTMCell = getattr(tf.compat.v1.nn.rnn_cell, "BasicLSTMCell", None)
38-
LSTMCell = getattr(tf.compat.v1.nn.rnn_cell, "LSTMCell", None)
39-
GRUCell = getattr(tf.compat.v1.nn.rnn_cell, "GRUCell", None)
40-
MultiRNNCell = getattr(tf.compat.v1.nn.rnn_cell, "MultiRNNCell", None)
37+
try:
38+
BasicLSTMCell = getattr(tf.compat.v1.nn.rnn_cell, "BasicLSTMCell", None)
39+
LSTMCell = getattr(tf.compat.v1.nn.rnn_cell, "LSTMCell", None)
40+
GRUCell = getattr(tf.compat.v1.nn.rnn_cell, "GRUCell", None)
41+
MultiRNNCell = getattr(tf.compat.v1.nn.rnn_cell, "MultiRNNCell", None)
42+
except ImportError:
43+
pass
4144
dynamic_rnn = tf.compat.v1.nn.dynamic_rnn
4245
bidirectional_dynamic_rnn = tf.compat.v1.nn.bidirectional_dynamic_rnn
4346
else:

tests/test_grublock.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@
1717
# pylint: disable=invalid-name
1818

1919
if is_tf2():
20-
MultiRNNCell = getattr(tf.compat.v1.nn.rnn_cell, "MultiRNNCell", None)
20+
try:
21+
MultiRNNCell = getattr(tf.compat.v1.nn.rnn_cell, "MultiRNNCell", None)
22+
except ImportError:
23+
pass
2124
dynamic_rnn = tf.compat.v1.nn.dynamic_rnn
2225
bidirectional_dynamic_rnn = tf.compat.v1.nn.bidirectional_dynamic_rnn
2326
else:

tests/test_lstm.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,12 @@
2121

2222
if is_tf2():
2323
# There is no LSTMBlockCell in tf-2.x
24-
BasicLSTMCell = getattr(tf.compat.v1.nn.rnn_cell, "BasicLSTMCell", None)
25-
LSTMCell = getattr(tf.compat.v1.nn.rnn_cell, "LSTMCell", None)
26-
MultiRNNCell = getattr(tf.compat.v1.nn.rnn_cell, "MultiRNNCell", None)
24+
try:
25+
BasicLSTMCell = getattr(tf.compat.v1.nn.rnn_cell, "BasicLSTMCell", None)
26+
LSTMCell = getattr(tf.compat.v1.nn.rnn_cell, "LSTMCell", None)
27+
MultiRNNCell = getattr(tf.compat.v1.nn.rnn_cell, "MultiRNNCell", None)
28+
except ImportError:
29+
pass
2730
dynamic_rnn = tf.compat.v1.nn.dynamic_rnn
2831
bidirectional_dynamic_rnn = tf.compat.v1.nn.bidirectional_dynamic_rnn
2932
else:

tests/test_lstmblock.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,7 @@
1616

1717
if is_tf2():
1818
# There is no LSTMBlockCell in tf-2.x
19-
MultiRNNCell = getattr(tf.compat.v1.nn.rnn_cell, "MultiRNNCell", None)
20-
dynamic_rnn = tf.compat.v1.nn.dynamic_rnn
21-
bidirectional_dynamic_rnn = tf.compat.v1.nn.bidirectional_dynamic_rnn
19+
pass
2220
else:
2321
LSTMBlockCell = tf.contrib.rnn.LSTMBlockCell
2422
MultiRNNCell = tf.contrib.rnn.MultiRNNCell

tests/test_seq2seq.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,7 @@
1313
# pylint: disable=invalid-name
1414

1515
if is_tf2():
16-
BasicLSTMCell = getattr(tf.compat.v1.nn.rnn_cell, "BasicLSTMCell", None)
17-
LSTMCell = getattr(tf.compat.v1.nn.rnn_cell, "LSTMCell", None)
18-
RNNCell = getattr(tf.compat.v1.nn.rnn_cell, "RNNCell", None)
19-
MultiRNNCell = getattr(tf.compat.v1.nn.rnn_cell, "MultiRNNCell", None)
20-
dynamic_rnn = tf.compat.v1.nn.dynamic_rnn
21-
bidirectional_dynamic_rnn = tf.compat.v1.nn.bidirectional_dynamic_rnn
22-
LSTMStateTuple = getattr(tf.compat.v1.nn.rnn_cell, "LSTMStateTuple", None)
16+
pass
2317
else:
2418
LSTMCell = tf.contrib.rnn.LSTMCell
2519
LSTMBlockCell = tf.contrib.rnn.LSTMBlockCell

tests/test_stacked_lstm.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,11 @@
1616
# pylint: disable=invalid-name
1717

1818
if is_tf2():
19-
LSTMCell = getattr(tf.compat.v1.nn.rnn_cell, "LSTMCell", None)
20-
MultiRNNCell = getattr(tf.compat.v1.nn.rnn_cell, "MultiRNNCell", None)
19+
try:
20+
LSTMCell = getattr(tf.compat.v1.nn.rnn_cell, "LSTMCell", None)
21+
MultiRNNCell = getattr(tf.compat.v1.nn.rnn_cell, "MultiRNNCell", None)
22+
except ImportError:
23+
pass
2124
dynamic_rnn = tf.compat.v1.nn.dynamic_rnn
2225
else:
2326
LSTMCell = tf.contrib.rnn.LSTMCell

tests/utils/setup_test_env.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,10 @@ if [[ $TF_VERSION == 1.* ]]; then
2323
pip install numpy==1.19.0
2424
else
2525
if [[ "$TF_VERSION" != "2.13.0" && "$TF_VERSION" != "2.9.0" ]]; then
26-
echo "-- install-3 TF-KERAS ${{ inputs.tf_version }}"
26+
echo "-- install-3 TF-KERAS $TF_VERSION"
2727
pip install tf_keras==$TF_VERSION
2828
else
29-
echo "-- install-3 TF ${{ inputs.tf_version }}"
29+
echo "-- install-3 TF $TF_VERSION"
3030
fi
3131
fi
3232

0 commit comments

Comments
 (0)