|
6 | 6 |
|
7 | 7 | import multiprocessing as py_mp
|
8 | 8 | import queue
|
| 9 | +import warnings |
9 | 10 |
|
10 | 11 | from abc import ABC, abstractmethod
|
11 | 12 | from datetime import timedelta
|
@@ -534,13 +535,20 @@ def finalize_iteration(self) -> None:
|
534 | 535 | def checkpoint(self) -> bytes:
|
535 | 536 | states = []
|
536 | 537 | for rs in self.reading_services:
|
537 |
| - states.append(rs.checkpoint()) |
| 538 | + if hasattr(rs, "checkpoint") and callable(rs.checkpoint): |
| 539 | + states.append(rs.checkpoint()) |
| 540 | + else: |
| 541 | + warnings.warn(f"{rs} doesn't support `checkpoint`, skipping...") |
| 542 | + states.append(b"") |
538 | 543 | return b"\n".join(states)
|
539 | 544 |
|
540 | 545 | # Sequential Order, to align with initialize
|
541 | 546 | def restore(self, datapipe, serialized_state: bytes) -> DataPipe:
|
542 | 547 | states = serialized_state.split(b"\n")
|
543 | 548 | assert len(states) == len(self.reading_services)
|
544 | 549 | for rs, state in zip(self.reading_services, states):
|
545 |
| - datapipe = rs.restore(datapipe, state) |
| 550 | + if hasattr(rs, "restore") and callable(rs.restore): |
| 551 | + datapipe = rs.restore(datapipe, state) |
| 552 | + else: |
| 553 | + warnings.warn(f"{rs} doesn't support `restore` from state, skipping...") |
546 | 554 | return datapipe
|
0 commit comments