Commit a971a6a

abhinavarora, pre-commit-ci[bot], and awaelchli authored and committed
Remove references to torchtext.legacy from PyTorch Lightning (#10724)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <[email protected]>
1 parent 114ac41 commit a971a6a

7 files changed: +123 −45 lines changed

pytorch_lightning/utilities/apply_func.py

Lines changed: 11 additions & 5 deletions
@@ -23,9 +23,10 @@
 import numpy as np
 import torch

-from pytorch_lightning.utilities.imports import _compare_version, _TORCHTEXT_AVAILABLE
+from pytorch_lightning.utilities.imports import _compare_version, _TORCHTEXT_LEGACY
+from pytorch_lightning.utilities.warnings import rank_zero_deprecation

-if _TORCHTEXT_AVAILABLE:
+if _TORCHTEXT_LEGACY:
     if _compare_version("torchtext", operator.ge, "0.9.0"):
         from torchtext.legacy.data import Batch
     else:

@@ -260,8 +261,13 @@ def move_data_to_device(batch: Any, device: Union[str, torch.device]) -> Any:

     def batch_to(data: Any) -> Any:
         # try to move torchtext data first
-        if _TORCHTEXT_AVAILABLE and isinstance(data, Batch):
-
+        if _TORCHTEXT_LEGACY and isinstance(data, Batch):
+            # TODO: also remove the torchtext dependency with Lightning 1.8
+            rank_zero_deprecation(
+                "The `torchtext.legacy.Batch` object is deprecated and Lightning will remove support for it in v1.8."
+                " We recommend you to migrate away from Batch by following the TorchText README:"
+                " https://github.com/pytorch/text#bc-breaking-legacy"
+            )
             # Shallow copy because each Batch has a reference to Dataset which contains all examples
             device_data = copy(data)
             for field, field_value in data.dataset.fields.items():

@@ -281,7 +287,7 @@ def batch_to(data: Any) -> Any:
         # user wrongly implemented the `TransferableDataType` and forgot to return `self`.
         return data

-    dtype = (TransferableDataType, Batch) if _TORCHTEXT_AVAILABLE else TransferableDataType
+    dtype = (TransferableDataType, Batch) if _TORCHTEXT_LEGACY else TransferableDataType
     return apply_to_collection(batch, dtype=dtype, function=batch_to)
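For reference, a minimal sketch of how the new deprecation path is hit from user code (assuming a torchtext release older than 0.11 is installed so that _TORCHTEXT_LEGACY is True; the helper used here is added later in this commit in tests/helpers/torchtext_utils.py):

import torch

from pytorch_lightning.utilities.apply_func import move_data_to_device
from tests.helpers.torchtext_utils import get_dummy_torchtext_data_iterator

data_iterator, _ = get_dummy_torchtext_data_iterator(num_samples=3, batch_size=3)
batch = next(iter(data_iterator))  # a legacy torchtext Batch
# batch_to() recognizes the legacy Batch, emits the rank-zero deprecation warning,
# and still shallow-copies the Batch and moves each field to the requested device.
moved = move_data_to_device(batch, torch.device("cpu"))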

pytorch_lightning/utilities/imports.py

Lines changed: 1 addition & 0 deletions
@@ -118,6 +118,7 @@ def _compare_version(package: str, op: Callable, version: str, use_base_version:
 _RICH_AVAILABLE = _package_available("rich") and _compare_version("rich", operator.ge, "10.2.2")
 _TORCH_QUANTIZE_AVAILABLE = bool([eg for eg in torch.backends.quantized.supported_engines if eg != "none"])
 _TORCHTEXT_AVAILABLE = _package_available("torchtext")
+_TORCHTEXT_LEGACY: bool = _TORCHTEXT_AVAILABLE and _compare_version("torchtext", operator.lt, "0.11.0")
 _TORCHVISION_AVAILABLE = _package_available("torchvision")
 _XLA_AVAILABLE: bool = _package_available("torch_xla")
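An illustrative check of what the new flag means (not part of the diff; the commented values assume a legacy torchtext release such as 0.10.x is installed):

from pytorch_lightning.utilities.imports import _TORCHTEXT_AVAILABLE, _TORCHTEXT_LEGACY

print(_TORCHTEXT_AVAILABLE)  # True whenever torchtext can be imported
print(_TORCHTEXT_LEGACY)     # True only for torchtext < 0.11.0, the releases Lightning still treats as legacy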

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Test deprecated functionality which will be removed in v1.8.0."""
+import pytest
+import torch
+
+from pytorch_lightning.utilities.apply_func import move_data_to_device
+from pytorch_lightning.utilities.imports import _TORCHTEXT_LEGACY
+from tests.helpers.torchtext_utils import get_dummy_torchtext_data_iterator
+
+
+@pytest.mark.skipif(not _TORCHTEXT_LEGACY, reason="torchtext.legacy is deprecated.")
+def test_v1_8_0_deprecated_torchtext_batch():
+
+    with pytest.deprecated_call(match="is deprecated and Lightning will remove support for it in v1.8"):
+        data_iterator, _ = get_dummy_torchtext_data_iterator(num_samples=3, batch_size=3)
+        batch = next(iter(data_iterator))
+        _ = move_data_to_device(batch=batch, device=torch.device("cpu"))

tests/helpers/imports.py

Lines changed: 14 additions & 4 deletions
@@ -1,6 +1,16 @@
-from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8
+import operator

-if _TORCH_GREATER_EQUAL_1_8:
-    from torchtext.legacy.data import Batch, Dataset, Example, Field, Iterator, LabelField
+from pytorch_lightning.utilities.imports import _compare_version, _TORCHTEXT_LEGACY
+
+if _TORCHTEXT_LEGACY:
+    if _compare_version("torchtext", operator.ge, "0.9.0"):
+        from torchtext.legacy.data import Batch, Dataset, Example, Field, Iterator, LabelField
+    else:
+        from torchtext.data import Batch, Dataset, Example, Field, Iterator, LabelField
 else:
-    from torchtext.data import Batch, Dataset, Example, Field, Iterator, LabelField  # noqa: F401
+    Batch = type(None)
+    Dataset = type(None)
+    Example = type(None)
+    Field = type(None)
+    Iterator = type(None)
+    LabelField = type(None)
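The type(None) fallbacks keep these names importable when no legacy torchtext is installed, so isinstance checks in the tests stay valid; a small illustration of the idea (illustrative only, not part of the diff):

Batch = type(None)  # placeholder used when torchtext.legacy is unavailable

# isinstance against NoneType is harmless: only None would ever match,
# so code guarded by `isinstance(data, Batch)` simply never triggers.
assert not isinstance({"text": "hello"}, Batch)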

tests/helpers/torchtext_utils.py

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import random
+import string
+
+from tests.helpers.imports import Dataset, Example, Field, Iterator
+
+
+def _generate_random_string(length: int = 10):
+    return "".join(random.choices(string.ascii_letters, k=length))
+
+
+def get_dummy_torchtext_data_iterator(num_samples: int, batch_size: int, include_lengths: bool = False):
+    text_field = Field(
+        sequential=True,
+        pad_first=False,  # nosec
+        init_token="<s>",
+        eos_token="</s>",  # nosec
+        include_lengths=include_lengths,
+    )  # nosec
+
+    dataset = Dataset(
+        [
+            Example.fromdict({"text": _generate_random_string()}, {"text": ("text", text_field)})
+            for _ in range(num_samples)
+        ],
+        {"text": text_field},
+    )
+    text_field.build_vocab(dataset)
+
+    iterator = Iterator(
+        dataset,
+        batch_size=batch_size,
+        sort_key=None,
+        device=None,
+        batch_size_fn=None,
+        train=True,
+        repeat=False,
+        shuffle=None,
+        sort=None,
+        sort_within_batch=None,
+    )
+    return iterator, text_field
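Example usage of the new helper, mirroring how the updated tests call it (illustrative snippet, not part of the diff):

from tests.helpers.torchtext_utils import get_dummy_torchtext_data_iterator

data_iterator, text_field = get_dummy_torchtext_data_iterator(num_samples=3, batch_size=3, include_lengths=True)
batch = next(iter(data_iterator))  # one legacy Batch holding all three random-string examples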

tests/models/test_gpu.py

Lines changed: 6 additions & 2 deletions
@@ -26,7 +26,7 @@
 from pytorch_lightning.plugins.environments import TorchElasticEnvironment
 from pytorch_lightning.utilities import device_parser
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _compare_version
+from pytorch_lightning.utilities.imports import _compare_version, _TORCHTEXT_LEGACY
 from tests.helpers import BoringModel
 from tests.helpers.datamodules import ClassifDataModule
 from tests.helpers.imports import Batch, Dataset, Example, Field, LabelField

@@ -309,6 +309,9 @@ def to(self, *args, **kwargs):
     assert batch.a.type() == "torch.cuda.FloatTensor"

     # torchtext.data.Batch
+    if not _TORCHTEXT_LEGACY:
+        return
+
     samples = [
         {"text": "PyTorch Lightning is awesome!", "label": 0},
         {"text": "Please make it work with torchtext", "label": 1},

@@ -326,7 +329,8 @@ def to(self, *args, **kwargs):
     label_field.build_vocab(dataset)

     batch = Batch(data=examples, dataset=dataset)
-    batch = trainer.accelerator.batch_to_device(batch, torch.device("cuda:0"))
+    with pytest.deprecated_call(match="The `torchtext.legacy.Batch` object is deprecated"):
+        batch = trainer.accelerator.batch_to_device(batch, torch.device("cuda:0"))

     assert batch.text.type() == "torch.cuda.LongTensor"
     assert batch.label.type() == "torch.cuda.LongTensor"

tests/utilities/test_apply_func_torchtext.py

Lines changed: 8 additions & 34 deletions
@@ -15,49 +15,22 @@
 import torch

 from pytorch_lightning.utilities.apply_func import move_data_to_device
-from tests.helpers.imports import Dataset, Example, Field, Iterator
+from pytorch_lightning.utilities.imports import _TORCHTEXT_LEGACY
 from tests.helpers.runif import RunIf
-
-
-def _get_torchtext_data_iterator(include_lengths=False):
-    text_field = Field(
-        sequential=True,
-        pad_first=False,  # nosec
-        init_token="<s>",
-        eos_token="</s>",  # nosec
-        include_lengths=include_lengths,
-    )  # nosec
-
-    example1 = Example.fromdict({"text": "a b c a c"}, {"text": ("text", text_field)})
-    example2 = Example.fromdict({"text": "b c a a"}, {"text": ("text", text_field)})
-    example3 = Example.fromdict({"text": "c b a"}, {"text": ("text", text_field)})
-
-    dataset = Dataset([example1, example2, example3], {"text": text_field})
-    text_field.build_vocab(dataset)
-
-    iterator = Iterator(
-        dataset,
-        batch_size=3,
-        sort_key=None,
-        device=None,
-        batch_size_fn=None,
-        train=True,
-        repeat=False,
-        shuffle=None,
-        sort=None,
-        sort_within_batch=None,
-    )
-    return iterator, text_field
+from tests.helpers.torchtext_utils import get_dummy_torchtext_data_iterator


 @pytest.mark.parametrize("include_lengths", [False, True])
 @pytest.mark.parametrize("device", [torch.device("cuda", 0)])
+@pytest.mark.skipif(not _TORCHTEXT_LEGACY, reason="torchtext.legacy is deprecated.")
 @RunIf(min_gpus=1)
 def test_batch_move_data_to_device_torchtext_include_lengths(include_lengths, device):
-    data_iterator, _ = _get_torchtext_data_iterator(include_lengths=include_lengths)
+    data_iterator, _ = get_dummy_torchtext_data_iterator(num_samples=3, batch_size=3, include_lengths=include_lengths)
     data_iter = iter(data_iterator)
     batch = next(data_iter)
-    batch_on_device = move_data_to_device(batch, device)
+
+    with pytest.deprecated_call(match="The `torchtext.legacy.Batch` object is deprecated"):
+        batch_on_device = move_data_to_device(batch, device)

     if include_lengths:
         # tensor with data

@@ -69,5 +42,6 @@ def test_batch_move_data_to_device_torchtext_include_lengths(include_lengths, de


 @pytest.mark.parametrize("include_lengths", [False, True])
+@pytest.mark.skipif(not _TORCHTEXT_LEGACY, reason="torchtext.legacy is deprecated.")
 def test_batch_move_data_to_device_torchtext_include_lengths_cpu(include_lengths):
     test_batch_move_data_to_device_torchtext_include_lengths(include_lengths, torch.device("cpu"))
