Skip to content

Commit cfad591

Browse files
Fix implementation of skip_tf_versions (#1653)
* Fix implementation of skip_tf_versions Signed-off-by: Tom Wildenhain <[email protected]> * Fix LSTM test Signed-off-by: Tom Wildenhain <[email protected]> Co-authored-by: Guenther Schmuelling <[email protected]>
1 parent 2883580 commit cfad591

File tree

2 files changed

+11
-10
lines changed

2 files changed

+11
-10
lines changed

tests/common.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -286,19 +286,18 @@ def check_tf_min_version(min_required_version, message=""):
286286

287287

288288
def skip_tf_versions(excluded_versions, message=""):
289-
""" Skip if tf_version SEMANTICALLY matches any of excluded_versions. """
289+
""" Skip if tf_version matches any of excluded_versions. """
290+
if not isinstance(excluded_versions, list):
291+
excluded_versions = [excluded_versions]
290292
config = get_test_config()
291293
condition = False
292294
reason = _append_message("conversion excludes tf {}".format(excluded_versions), message)
293295

294-
current_tokens = str(config.tf_version).split('.')
295296
for excluded_version in excluded_versions:
296-
exclude_tokens = excluded_version.split('.')
297-
# assume len(exclude_tokens) <= len(current_tokens)
298-
for i, exclude in enumerate(exclude_tokens):
299-
if not current_tokens[i] == exclude:
300-
break
301-
condition = True
297+
# tf version with same specificity as excluded_version
298+
tf_version = '.'.join(str(config.tf_version).split('.')[:excluded_version.count('.') + 1])
299+
if excluded_version == tf_version:
300+
condition = True
302301

303302
return unittest.skipIf(condition, reason)
304303

tests/test_lstm.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -674,7 +674,8 @@ def func(x):
674674
feed_dict = {"input_1:0": x_val}
675675
input_names_with_port = ["input_1:0"]
676676
output_names_with_port = ["output_1:0", "cell_state_1:0", "output_2:0", "cell_state_2:0"]
677-
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06)
677+
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06,
678+
require_lstm_count=2)
678679

679680
@check_opset_after_tf_version("1.15", 10, "might need ReverseV2")
680681
@skip_tf_versions("2.1", "Bug in TF 2.1")
@@ -721,7 +722,8 @@ def func(x, y1, y2):
721722
feed_dict = {"input_1:0": x_val, "input_2:0": seq_len_val, "input_3:0": seq_len_val}
722723
input_names_with_port = ["input_1:0", "input_2:0", "input_3:0"]
723724
output_names_with_port = ["output_1:0", "cell_state_1:0", "output_2:0", "cell_state_2:0"]
724-
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06)
725+
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06,
726+
require_lstm_count=2)
725727

726728

727729
if __name__ == '__main__':

0 commit comments

Comments
 (0)