Skip to content

Commit 48acea8

Browse files
authored
Support TensorFlow 2.7 (#902)
* Support TensorFlow 2.7 * Fix formatting * Skip TensorFlow Lite test case with TensorFlow 2.7
1 parent 8b548d4 commit 48acea8

File tree

7 files changed

+32
-16
lines changed

7 files changed

+32
-16
lines changed

.github/workflows/ci.yml

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@ jobs:
1717
steps:
1818
- uses: actions/checkout@v2
1919

20-
- name: Set up Python 3.6
20+
- name: Set up Python 3.8
2121
uses: actions/setup-python@v2
2222
with:
23-
python-version: 3.6
23+
python-version: 3.8
2424

2525
- name: Install dependencies
2626
run: |
@@ -44,15 +44,15 @@ jobs:
4444
runs-on: ubuntu-latest
4545
strategy:
4646
matrix:
47-
tensorflow: [2.4, 2.5, 2.6]
47+
tensorflow: [2.4, 2.5, 2.6, 2.7]
4848

4949
steps:
5050
- uses: actions/checkout@v2
5151

52-
- name: Set up Python 3.6
52+
- name: Set up Python 3.8
5353
uses: actions/setup-python@v2
5454
with:
55-
python-version: 3.6
55+
python-version: 3.8
5656

5757
- name: Install dependencies
5858
run: |
@@ -71,7 +71,7 @@ jobs:
7171
pytest --cov=opennmt --cov-report xml opennmt/tests
7272
7373
- name: Upload coverage report
74-
if: matrix.tensorflow == '2.6'
74+
if: matrix.tensorflow == '2.7'
7575
uses: codecov/codecov-action@v2
7676

7777

@@ -82,10 +82,10 @@ jobs:
8282
steps:
8383
- uses: actions/checkout@v2
8484

85-
- name: Set up Python 3.6
85+
- name: Set up Python 3.8
8686
uses: actions/setup-python@v2
8787
with:
88-
python-version: 3.6
88+
python-version: 3.8
8989

9090
- name: Install dependencies
9191
run: |
@@ -112,10 +112,10 @@ jobs:
112112
with:
113113
persist-credentials: false
114114

115-
- name: Set up Python 3.6
115+
- name: Set up Python 3.8
116116
uses: actions/setup-python@v2
117117
with:
118-
python-version: 3.6
118+
python-version: 3.8
119119

120120
- name: Install dependencies
121121
run: |

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ OpenNMT-tf also implements most of the techniques commonly used to train and eva
9797
OpenNMT-tf requires:
9898

9999
* Python 3.6 or above
100-
* TensorFlow 2.4, 2.5, or 2.6
100+
* TensorFlow 2.4, 2.5, 2.6, or 2.7
101101

102102
We recommend installing it with `pip`:
103103

docs/installation.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
OpenNMT-tf requires:
66

77
* Python 3.6 or above
8-
* TensorFlow 2.4, 2.5, or 2.6
8+
* TensorFlow 2.4, 2.5, 2.6, or 2.7
99

1010
For GPU support, please read the [TensorFlow documentation](https://www.tensorflow.org/install/gpu) for additional software and hardware requirements.
1111

opennmt/data/dataset.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,11 @@ def _transform(dataset):
254254

255255
# Take all original batches and the number of extra batches required.
256256
dataset = dataset.enumerate()
257-
dataset = dataset.apply(tf.data.experimental.take_while(_continue_iter))
257+
# TODO: clean this API when TensorFlow requirement is updated to >=2.6.
258+
if compat.tf_supports("data.Dataset.take_while"):
259+
dataset = dataset.take_while(_continue_iter)
260+
else:
261+
dataset = dataset.apply(tf.data.experimental.take_while(_continue_iter))
258262
return dataset.map(_retrieve_element) # Retrieve the element only.
259263

260264
return _transform
@@ -588,7 +592,13 @@ def _make_weighted_dataset(datasets, weights):
588592
datasets = [dataset.shard(num_shards, shard_index) for dataset in datasets]
589593
weights = normalize_weights(datasets, weights=weights, sizes=dataset_size)
590594
datasets = [dataset.repeat() for dataset in datasets]
591-
dataset = tf.data.experimental.sample_from_datasets(datasets, weights=weights)
595+
# TODO: clean this API when TensorFlow requirement is updated to >=2.7.
596+
if compat.tf_supports("data.Dataset.sample_from_datasets"):
597+
dataset = tf.data.Dataset.sample_from_datasets(datasets, weights=weights)
598+
else:
599+
dataset = tf.data.experimental.sample_from_datasets(
600+
datasets, weights=weights
601+
)
592602
if shuffle_buffer_size is not None and shuffle_buffer_size != 0:
593603
if shuffle_buffer_size < 0:
594604
raise ValueError(

opennmt/tests/tflite_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,12 @@ def testTFLiteOutput(self, model, params):
118118
version.parse(tf.__version__) < version.parse("2.5.0"),
119119
reason="TensorFlow Lite exporting requires TensorFlow 2.5+",
120120
)
121+
@pytest.mark.skipif(
122+
version.parse("2.7.0")
123+
<= version.parse(tf.__version__)
124+
< version.parse("2.8.0"),
125+
reason="Test case failing with TensorFlow 2.7",
126+
)
121127
def testTFLiteInterpreter(self, model, params=None, quantization=None):
122128
if params is None:
123129
params = {}

opennmt/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
__version__ = "2.22.0"
44

55
INCLUSIVE_MIN_TF_VERSION = "2.4.0"
6-
EXCLUSIVE_MAX_TF_VERSION = "2.7.0"
6+
EXCLUSIVE_MAX_TF_VERSION = "2.8.0"
77

88

99
def _check_tf_version():

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def get_project_version():
6969
"pyyaml>=5.3,<7",
7070
"rouge>=1.0,<2",
7171
"sacrebleu>=1.5.0,<2.1",
72-
"tensorflow-addons>=0.14,<0.15",
72+
"tensorflow-addons>=0.14,<0.16",
7373
],
7474
extras_require={
7575
"tensorflow": [

0 commit comments

Comments
 (0)