11import multiprocessing as mp
22import os
3+ from collections .abc import Iterator
34from queue import Queue
4- from typing import Iterator
55
66import numpy as np
7+ from torch .utils .data import DataLoader , IterableDataset
8+
79from lightning import Trainer
810from lightning .pytorch .demos .boring_classes import BoringModel
9- from torch . utils . data import DataLoader , IterableDataset
11+
1012
1113class QueueDataset (IterableDataset ):
1214 def __init__ (self , queue : Queue ) -> None :
@@ -18,13 +20,15 @@ def __iter__(self) -> Iterator:
1820 tensor , _ = self .queue .get (timeout = 10 )
1921 yield tensor
2022
23+
def create_queue():
    """Build a multiprocessing queue pre-filled with ten ``(array, index)`` samples.

    One random ``float32`` array of shape ``(1, 32)`` is generated and reused
    for every entry; only the integer index (0 through 9) differs per item.

    Returns:
        The populated ``multiprocessing.Queue``.
    """
    sample = np.random.random([1, 32]).astype(np.float32)
    sample_queue = mp.Queue()
    for position in range(10):
        sample_queue.put((sample, position))
    return sample_queue
2730
31+
2832def train_model (queue , maxEpochs , ckptPath ):
2933 dataloader = DataLoader (QueueDataset (queue ), num_workers = 1 , batch_size = None , persistent_workers = True )
3034 trainer = Trainer (max_epochs = maxEpochs , enable_progress_bar = False , devices = 1 )
@@ -36,24 +40,24 @@ def train_model(queue, maxEpochs, ckptPath):
3640 trainer .save_checkpoint (ckptPath )
3741 return trainer
3842
43+
def test_training():
    """Test that reproduces issue in calling iter twice on a queue-based IterableDataset leads to Queue Empty errors
    when resuming from a checkpoint.

    The same queue is passed to ``train_model`` twice with the same checkpoint
    path: the first call must leave a non-empty checkpoint file on disk, and
    the second call must also complete and return a trainer.

    The checkpoint lives in a fresh temporary directory, so a stale
    ``model.ckpt`` left over from an earlier run can neither satisfy the
    existence/size assertions nor be seen by the first training run, and no
    file is left behind in the working directory.
    """
    import tempfile  # local import: only this test needs it

    queue = create_queue()

    with tempfile.TemporaryDirectory() as tmp_dir:
        ckpt_path = os.path.join(tmp_dir, "model.ckpt")

        # First run: must produce the checkpoint.
        trainer = train_model(queue, 1, ckpt_path)
        assert trainer is not None

        assert os.path.exists(ckpt_path), "Checkpoint file wasn't created"

        ckpt_size = os.path.getsize(ckpt_path)
        assert ckpt_size > 0, f"Checkpoint file is empty (size: {ckpt_size} bytes)"

        # Second run on the same queue with the checkpoint now present;
        # this is the step the docstring's Queue Empty issue refers to.
        trainer = train_model(queue, 1, ckpt_path)
        assert trainer is not None
5760
61+
# Allow running this reproduction directly as a script, without a test runner.
if __name__ == "__main__":
    test_training()
0 commit comments