1515import torch
1616
1717from pytorch_lightning .utilities .apply_func import move_data_to_device
18- from tests . helpers .imports import Dataset , Example , Field , Iterator
18+ from pytorch_lightning . utilities .imports import _TORCHTEXT_LEGACY
1919from tests .helpers .runif import RunIf
20-
21-
22- def _get_torchtext_data_iterator (include_lengths = False ):
23- text_field = Field (
24- sequential = True ,
25- pad_first = False , # nosec
26- init_token = "<s>" ,
27- eos_token = "</s>" , # nosec
28- include_lengths = include_lengths ,
29- ) # nosec
30-
31- example1 = Example .fromdict ({"text" : "a b c a c" }, {"text" : ("text" , text_field )})
32- example2 = Example .fromdict ({"text" : "b c a a" }, {"text" : ("text" , text_field )})
33- example3 = Example .fromdict ({"text" : "c b a" }, {"text" : ("text" , text_field )})
34-
35- dataset = Dataset ([example1 , example2 , example3 ], {"text" : text_field })
36- text_field .build_vocab (dataset )
37-
38- iterator = Iterator (
39- dataset ,
40- batch_size = 3 ,
41- sort_key = None ,
42- device = None ,
43- batch_size_fn = None ,
44- train = True ,
45- repeat = False ,
46- shuffle = None ,
47- sort = None ,
48- sort_within_batch = None ,
49- )
50- return iterator , text_field
20+ from tests .helpers .torchtext_utils import get_dummy_torchtext_data_iterator
5121
5222
5323@pytest .mark .parametrize ("include_lengths" , [False , True ])
5424@pytest .mark .parametrize ("device" , [torch .device ("cuda" , 0 )])
25+ @pytest .mark .skipif (not _TORCHTEXT_LEGACY , reason = "torchtext.legacy is deprecated." )
5526@RunIf (min_gpus = 1 )
5627def test_batch_move_data_to_device_torchtext_include_lengths (include_lengths , device ):
57- data_iterator , _ = _get_torchtext_data_iterator ( include_lengths = include_lengths )
28+ data_iterator , _ = get_dummy_torchtext_data_iterator ( num_samples = 3 , batch_size = 3 , include_lengths = include_lengths )
5829 data_iter = iter (data_iterator )
5930 batch = next (data_iter )
60- batch_on_device = move_data_to_device (batch , device )
31+
32+ with pytest .deprecated_call (match = "The `torchtext.legacy.Batch` object is deprecated" ):
33+ batch_on_device = move_data_to_device (batch , device )
6134
6235 if include_lengths :
6336 # tensor with data
@@ -69,5 +42,6 @@ def test_batch_move_data_to_device_torchtext_include_lengths(include_lengths, de
6942
7043
7144@pytest .mark .parametrize ("include_lengths" , [False , True ])
45+ @pytest .mark .skipif (not _TORCHTEXT_LEGACY , reason = "torchtext.legacy is deprecated." )
7246def test_batch_move_data_to_device_torchtext_include_lengths_cpu (include_lengths ):
7347 test_batch_move_data_to_device_torchtext_include_lengths (include_lengths , torch .device ("cpu" ))
0 commit comments