Skip to content

Commit a7d3412

Browse files
pre-commit-ci[bot]iyilmaz24
authored andcommitted
test: merge changes from formatting bot
1 parent d9b5327 commit a7d3412

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed
Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import multiprocessing as mp
22
import os
3+
from collections.abc import Iterator
34
from queue import Queue
4-
from typing import Iterator
55

66
import numpy as np
7+
from torch.utils.data import DataLoader, IterableDataset
8+
79
from lightning import Trainer
810
from lightning.pytorch.demos.boring_classes import BoringModel
9-
from torch.utils.data import DataLoader, IterableDataset
11+
1012

1113
class QueueDataset(IterableDataset):
1214
def __init__(self, queue: Queue) -> None:
@@ -18,13 +20,15 @@ def __iter__(self) -> Iterator:
1820
tensor, _ = self.queue.get(timeout=5)
1921
yield tensor
2022

23+
2124
def create_queue():
2225
q = mp.Queue()
2326
arr = np.random.random([1, 32]).astype(np.float32)
2427
for ind in range(10):
2528
q.put((arr, ind))
2629
return q
2730

31+
2832
def train_model(queue, maxEpochs, ckptPath):
2933
dataloader = DataLoader(QueueDataset(queue), num_workers=1, batch_size=None, persistent_workers=True)
3034
trainer = Trainer(max_epochs=maxEpochs, enable_progress_bar=False, devices=1)
@@ -35,21 +39,20 @@ def train_model(queue, maxEpochs, ckptPath):
3539
trainer.save_checkpoint(ckptPath)
3640
return trainer
3741

42+
3843
def test_training():
39-
"""
40-
Test that reproduces issue in calling iter twice on a queue-based
41-
IterableDataset leads to Queue Empty errors when resuming from a checkpoint.
42-
"""
44+
"""Test that reproduces issue in calling iter twice on a queue-based IterableDataset leads to Queue Empty errors
45+
when resuming from a checkpoint."""
4346
queue = create_queue()
4447

4548
ckpt_path = "model.ckpt"
4649
trainer = train_model(queue, 1, ckpt_path)
4750
assert trainer is not None
48-
51+
4952
assert os.path.exists(ckpt_path), "Checkpoint file wasn't created"
50-
53+
5154
ckpt_size = os.path.getsize(ckpt_path)
5255
assert ckpt_size > 0, f"Checkpoint file is empty (size: {ckpt_size} bytes)"
53-
56+
5457
trainer = train_model(queue, 2, ckpt_path)
5558
assert trainer is not None

0 commit comments

Comments
 (0)