15
15
import torch
16
16
17
17
from pytorch_lightning .utilities .apply_func import move_data_to_device
18
- from tests . helpers .imports import Dataset , Example , Field , Iterator
18
+ from pytorch_lightning . utilities .imports import _TORCHTEXT_LEGACY
19
19
from tests .helpers .runif import RunIf
20
-
21
-
22
def _get_torchtext_data_iterator(include_lengths=False):
    """Build a tiny torchtext ``Iterator`` over three toy sentences.

    Args:
        include_lengths: forwarded to the ``Field``; when True each batch
            carries a (padded tensor, lengths tensor) pair instead of a
            bare tensor.

    Returns:
        Tuple ``(iterator, text_field)`` where the field's vocab has
        already been built from the dataset.
    """
    field = Field(
        sequential=True,
        pad_first=False,  # nosec
        init_token="<s>",
        eos_token="</s>",  # nosec
        include_lengths=include_lengths,
    )  # nosec

    # Three short sentences over the vocabulary {a, b, c}.
    sentences = ["a b c a c", "b c a a", "c b a"]
    examples = [Example.fromdict({"text": s}, {"text": ("text", field)}) for s in sentences]

    data = Dataset(examples, {"text": field})
    field.build_vocab(data)

    return (
        Iterator(
            data,
            batch_size=3,
            sort_key=None,
            device=None,
            batch_size_fn=None,
            train=True,
            repeat=False,
            shuffle=None,
            sort=None,
            sort_within_batch=None,
        ),
        field,
    )
20
+ from tests .helpers .torchtext_utils import get_dummy_torchtext_data_iterator
51
21
52
22
53
23
@pytest .mark .parametrize ("include_lengths" , [False , True ])
54
24
@pytest .mark .parametrize ("device" , [torch .device ("cuda" , 0 )])
25
+ @pytest .mark .skipif (not _TORCHTEXT_LEGACY , reason = "torchtext.legacy is deprecated." )
55
26
@RunIf (min_gpus = 1 )
56
27
def test_batch_move_data_to_device_torchtext_include_lengths (include_lengths , device ):
57
- data_iterator , _ = _get_torchtext_data_iterator ( include_lengths = include_lengths )
28
+ data_iterator , _ = get_dummy_torchtext_data_iterator ( num_samples = 3 , batch_size = 3 , include_lengths = include_lengths )
58
29
data_iter = iter (data_iterator )
59
30
batch = next (data_iter )
60
- batch_on_device = move_data_to_device (batch , device )
31
+
32
+ with pytest .deprecated_call (match = "The `torchtext.legacy.Batch` object is deprecated" ):
33
+ batch_on_device = move_data_to_device (batch , device )
61
34
62
35
if include_lengths :
63
36
# tensor with data
@@ -69,5 +42,6 @@ def test_batch_move_data_to_device_torchtext_include_lengths(include_lengths, de
69
42
70
43
71
44
@pytest.mark.parametrize("include_lengths", [False, True])
@pytest.mark.skipif(not _TORCHTEXT_LEGACY, reason="torchtext.legacy is deprecated.")
def test_batch_move_data_to_device_torchtext_include_lengths_cpu(include_lengths):
    """Exercise the include-lengths move-to-device test on the CPU device."""
    cpu = torch.device("cpu")
    test_batch_move_data_to_device_torchtext_include_lengths(include_lengths, cpu)
0 commit comments