Skip to content

Commit e908330

Browse files
ejguanfacebook-github-bot
authored andcommitted
Make SeqRS invokes checkpoint/restore only when they are available (#1008)
Summary: Per title Pull Request resolved: #1008 Reviewed By: mingyuzh Differential Revision: D43243392 Pulled By: ejguan fbshipit-source-id: cad70b7a859481bc293a95e45979c98b589bb90a
1 parent 657b1dd commit e908330

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

torchdata/dataloader2/reading_service.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import multiprocessing as py_mp
88
import queue
9+
import warnings
910

1011
from abc import ABC, abstractmethod
1112
from datetime import timedelta
@@ -534,13 +535,20 @@ def finalize_iteration(self) -> None:
534535
def checkpoint(self) -> bytes:
535536
states = []
536537
for rs in self.reading_services:
537-
states.append(rs.checkpoint())
538+
if hasattr(rs, "checkpoint") and callable(rs.checkpoint):
539+
states.append(rs.checkpoint())
540+
else:
541+
warnings.warn(f"{rs} doesn't support `checkpoint`, skipping...")
542+
states.append(b"")
538543
return b"\n".join(states)
539544

540545
# Sequential Order, to align with initialize
541546
def restore(self, datapipe, serialized_state: bytes) -> DataPipe:
542547
states = serialized_state.split(b"\n")
543548
assert len(states) == len(self.reading_services)
544549
for rs, state in zip(self.reading_services, states):
545-
datapipe = rs.restore(datapipe, state)
550+
if hasattr(rs, "restore") and callable(rs.restore):
551+
datapipe = rs.restore(datapipe, state)
552+
else:
553+
warnings.warn(f"{rs} doesn't support `restore` from state, skipping...")
546554
return datapipe

0 commit comments

Comments
 (0)