Skip to content

Commit d9b5327

Browse files
committed
test: remove boilerplate and adjust parameters
- removed unnecessary “if __name__ == '__main__'” block - changed second Trainer's max_epochs from 1 to 2 - reduced queue.get() timeout to 5 seconds for faster test execution - deleted accidental extra trainer.fit() call
1 parent 696257b commit d9b5327

File tree

1 file changed

+2
-6
lines changed

1 file changed

+2
-6
lines changed

tests/tests_pytorch/loops/test_trainer_iterable_dataset_double_iter.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def __init__(self, queue: Queue) -> None:
1515

1616
def __iter__(self) -> Iterator:
1717
for _ in range(5):
18-
tensor, _ = self.queue.get(timeout=10)
18+
tensor, _ = self.queue.get(timeout=5)
1919
yield tensor
2020

2121
def create_queue():
@@ -28,7 +28,6 @@ def create_queue():
2828
def train_model(queue, maxEpochs, ckptPath):
2929
dataloader = DataLoader(QueueDataset(queue), num_workers=1, batch_size=None, persistent_workers=True)
3030
trainer = Trainer(max_epochs=maxEpochs, enable_progress_bar=False, devices=1)
31-
trainer.fit(BoringModel(), dataloader)
3231
if os.path.exists(ckptPath):
3332
trainer.fit(BoringModel(), dataloader, ckpt_path=ckptPath)
3433
else:
@@ -52,8 +51,5 @@ def test_training():
5251
ckpt_size = os.path.getsize(ckpt_path)
5352
assert ckpt_size > 0, f"Checkpoint file is empty (size: {ckpt_size} bytes)"
5453

55-
trainer = train_model(queue, 1, ckpt_path)
54+
trainer = train_model(queue, 2, ckpt_path)
5655
assert trainer is not None
57-
58-
if __name__ == "__main__":
59-
test_training()

0 commit comments

Comments
 (0)