Skip to content

Commit 4195db0

Browse files
philgzl and bhimrazy authored
Fix ParallelStreamingDataset resume (Lightning-AI#761)
Co-authored-by: Bhimraj Yadav <bhimrajyadav977@gmail.com>
1 parent 825586b commit 4195db0

File tree

2 files changed

+34
-4
lines changed

2 files changed

+34
-4
lines changed

src/litdata/streaming/dataloader.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -649,6 +649,7 @@ def __iter__(self) -> Any:
649649
# For ParallelStreamingDataset with _length != None we want to cycle the wrapped datasets i.e. we do not
650650
# want to restart at index 0 at every epoch. So we set them in restore state.
651651
self.load_state_dict(self.state_dict())
652+
self.restore = False
652653
else:
653654
self._latest_worker_idx = 0
654655
self._worker_idx = cycle(list(range(self.num_workers if self.num_workers > 0 else 1)))

tests/streaming/test_parallel.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -871,9 +871,25 @@ def test_parallel_dataset_partial_iteration_resume(tmp_path_factory, length, res
871871
assert all(torch.equal(x, y) for x, y in zip(batch, batches_1[i]))
872872
if i == break_at:
873873
break
874+
expected_3 = [
875+
[torch.tensor([4]), torch.tensor([4])],
876+
[torch.tensor([9]), torch.tensor([9])],
877+
[torch.tensor([0]), torch.tensor([0])],
878+
[torch.tensor([5]), torch.tensor([5])],
879+
]
880+
for i, batch in enumerate(dloader):
881+
if not shuffle:
882+
assert all(
883+
torch.equal(x, y)
884+
for x, y in zip(batch, (expected_3 if resume and length is not None else expected_1)[i])
885+
)
886+
elif not resume and length is not None:
887+
assert all(torch.equal(x, y) for x, y in zip(batch, batches_1[i]))
888+
if i == break_at:
889+
break
874890

875891

876-
@pytest.mark.parametrize("length", [None, 6])
892+
@pytest.mark.parametrize("length", [None, 5])
877893
@pytest.mark.parametrize("resume", [False, True])
878894
@pytest.mark.parametrize("shuffle", [False, True])
879895
@pytest.mark.skipif(sys.platform in ("win32", "darwin"), reason="too slow in CI")
@@ -888,26 +904,39 @@ def test_parallel_dataset_complete_iteration_resume(tmp_path_factory, length, re
888904
[torch.tensor([1]), torch.tensor([1])],
889905
[torch.tensor([3]), torch.tensor([3])],
890906
[torch.tensor([0]), torch.tensor([0])],
891-
[torch.tensor([2]), torch.tensor([2])],
892907
]
893908
batches_1 = []
894909
for i, batch in enumerate(dloader):
895910
if not shuffle:
896911
assert all(torch.equal(x, y) for x, y in zip(batch, expected_1[i]))
897912
batches_1.append(batch)
898913
expected_2 = [
914+
[torch.tensor([1]), torch.tensor([1])],
915+
[torch.tensor([2]), torch.tensor([2])],
916+
[torch.tensor([0]), torch.tensor([0])],
917+
[torch.tensor([3]), torch.tensor([3])],
918+
[torch.tensor([1]), torch.tensor([1])],
919+
]
920+
for i, batch in enumerate(dloader):
921+
if not shuffle:
922+
assert all(
923+
torch.equal(x, y)
924+
for x, y in zip(batch, (expected_2 if resume and length is not None else expected_1)[i])
925+
)
926+
elif not resume and length is not None:
927+
assert all(torch.equal(x, y) for x, y in zip(batch, batches_1[i]))
928+
expected_3 = [
899929
[torch.tensor([1]), torch.tensor([1])],
900930
[torch.tensor([3]), torch.tensor([3])],
901931
[torch.tensor([0]), torch.tensor([0])],
902932
[torch.tensor([2]), torch.tensor([2])],
903933
[torch.tensor([1]), torch.tensor([1])],
904-
[torch.tensor([3]), torch.tensor([3])],
905934
]
906935
for i, batch in enumerate(dloader):
907936
if not shuffle:
908937
assert all(
909938
torch.equal(x, y)
910-
for x, y in zip(batch, (expected_2 if resume and length is not None else expected_1)[i])
939+
for x, y in zip(batch, (expected_3 if resume and length is not None else expected_1)[i])
911940
)
912941
elif not resume and length is not None:
913942
assert all(torch.equal(x, y) for x, y in zip(batch, batches_1[i]))

0 commit comments

Comments (0)