diff --git a/test/stateful_dataloader/test_state_parser.py b/test/stateful_dataloader/test_state_parser.py
new file mode 100644
index 000000000..2b0e426bf
--- /dev/null
+++ b/test/stateful_dataloader/test_state_parser.py
@@ -0,0 +1,78 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+from torch.testing._internal.common_utils import TestCase
+
+from torch.utils.data import IterableDataset
+from torchdata.stateful_dataloader import Stateful, StatefulDataLoader, StateParserUtil
+
+
+class StatefulIterableDataset(IterableDataset, Stateful):
+    def __init__(self):
+        self.num_calls = 0
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        self.num_calls += 1
+        return self.num_calls
+
+    def load_state_dict(self, state_dict):
+        self.num_calls = state_dict["num_calls"]
+
+    def state_dict(self):
+        return {"num_calls": self.num_calls}
+
+
+def identity(x):
+    return x
+
+
+class TestIteratorDataset(TestCase):
+    def test_increasing_worker(self):
+        ds = StatefulIterableDataset()
+        dl = StatefulDataLoader(ds, num_workers=2, collate_fn=identity)
+        it = iter(dl)
+        next(it)
+        sd = dl.state_dict()
+        print(sd)
+        del dl
+
+        parser = StateParserUtil(sd)
+        worker_states = parser.fetch_dataset_state()
+        worker_states[2] = {"num_calls": 2}
+        worker_states[3] = {"num_calls": 3}
+        parser.set_dataset_state(worker_states)
+
+        # number of worker states doesn't match the num_workers setting
+        with self.assertRaises(AssertionError):
+            parser.get_state_dict()
+        parser.set_num_workers(4)
+
+        # last yielded worker id is out of range for num_workers
+        parser.set_last_worker_yielded_id(10)
+        with self.assertRaises(AssertionError):
+            parser.get_state_dict()
+        parser.set_last_worker_yielded_id(0)
+
+        # load the modified state
+        new_sd = parser.get_state_dict()
+        print(new_sd)
+        dl = StatefulDataLoader(ds, num_workers=4, collate_fn=identity)
+        dl.load_state_dict(new_sd)
+        it = iter(dl)
+        values = []
+        for _ in range(4):
+            values.extend(next(it))
+        print(values)
+        self.assertEqual(values, [1, 3, 4, 2])
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/torchdata/stateful_dataloader/__init__.py b/torchdata/stateful_dataloader/__init__.py
index 93d042ee7..52e8b2984 100644
--- a/torchdata/stateful_dataloader/__init__.py
+++ b/torchdata/stateful_dataloader/__init__.py
@@ -4,7 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from .state_parser import StateParserUtil
 from .stateful import Stateful
 from .stateful_dataloader import StatefulDataLoader
 
-__all__ = ["Stateful", "StatefulDataLoader"]
+__all__ = ["Stateful", "StatefulDataLoader", "StateParserUtil"]
diff --git a/torchdata/stateful_dataloader/state_parser.py b/torchdata/stateful_dataloader/state_parser.py
new file mode 100644
index 000000000..6d465d553
--- /dev/null
+++ b/torchdata/stateful_dataloader/state_parser.py
@@ -0,0 +1,74 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+import logging
+from typing import Any, Dict, Union
+
+logger = logging.getLogger(__name__)
+
+
+class StateParserUtil:
+    """
+    Utility class that can be used to modify the state returned by a StatefulDataLoader
+    """
+
+    def __init__(self, state_dict: Dict[str, Any]):
+        self._state_dict = state_dict
+        self._is_multiprocess_state = "_snapshot" in self._state_dict
+
+    def fetch_dataset_state(self) -> Union[Dict[int, Any], Any]:
+        # Handle both the single-process and multiprocess cases
+        if not self._is_multiprocess_state:
+            return self._state_dict["dataset_state"]
+        return {
+            state["worker_id"]: state["dataset_state"]
+            for state in self._state_dict["_snapshot"]["_worker_snapshots"].values()
+        }
+
+    def set_last_worker_yielded_id(self, last_worker_yielded: int) -> None:
+        # Range validation against num_workers happens in get_state_dict()
+        if not self._is_multiprocess_state:
+            logger.warning("Cannot set last worker yielded id on a single process state dict")
+            return
+        self._state_dict["_snapshot"]["_last_yielded_worker_id"] = last_worker_yielded
+
+    def set_num_workers(self, num_workers: int) -> None:
+        if not self._is_multiprocess_state:
+            logger.warning("Cannot set num_workers on a single process state dict")
+            return
+        self._state_dict["_snapshot"]["_main_snapshot"]["_num_workers"] = num_workers
+
+    def set_dataset_state(self, dataset_state: Union[Dict[int, Any], Any]) -> None:
+        if not self._is_multiprocess_state:
+            self._state_dict["dataset_state"] = dataset_state
+            return
+
+        for worker_id, state in dataset_state.items():
+            worker_states = self._state_dict["_snapshot"]["_worker_snapshots"]
+            worker_key = f"worker_{worker_id}"
+            if worker_key in worker_states:
+                worker_states[worker_key]["dataset_state"] = state
+            else:
+                worker_states[worker_key] = {"worker_id": worker_id, "dataset_state": state, "fetcher_state": None}
+
+    def get_state_dict(self) -> Dict[str, Any]:
+        # Perform validations
+        # a) num_workers should match the number of worker snapshots
+        # b) last yielded worker id should be within num_workers
+        if not self._is_multiprocess_state:
+            return self._state_dict
+
+        last_yielded_id = self._state_dict["_snapshot"]["_last_yielded_worker_id"]
+        num_workers = self._state_dict["_snapshot"]["_main_snapshot"]["_num_workers"]
+        worker_ids = self._state_dict["_snapshot"]["_worker_snapshots"].keys()
+
+        assert (
+            len(worker_ids) == num_workers
+        ), f"Number of worker states {len(worker_ids)} should be equal to num_workers setting {num_workers}"
+        assert set(worker_ids) == {
+            f"worker_{i}" for i in range(num_workers)
+        }, f"Worker state for all of [0, {num_workers}) should be present. Instead found state only for {sorted(worker_ids)}"
+        assert last_yielded_id < num_workers, "Last yielded id should be strictly within the number of workers"
+        return self._state_dict
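
Example usage (a minimal sketch, not part of the patch, based only on the API added above; my_dataset and identity are placeholder names for any dataset and collate_fn, and the 2 -> 4 worker resize with hand-written num_calls states mirrors the test rather than a required workflow):

    from torchdata.stateful_dataloader import StatefulDataLoader, StateParserUtil

    # Checkpoint a 2-worker dataloader after one batch.
    dl = StatefulDataLoader(my_dataset, num_workers=2, collate_fn=identity)
    it = iter(dl)
    next(it)
    sd = dl.state_dict()

    # Edit the checkpoint offline so it can be loaded with 4 workers.
    parser = StateParserUtil(sd)
    states = parser.fetch_dataset_state()   # {worker_id: dataset_state}
    states[2] = {"num_calls": 2}            # seed states for the two new workers
    states[3] = {"num_calls": 3}
    parser.set_dataset_state(states)
    parser.set_num_workers(4)
    parser.set_last_worker_yielded_id(0)
    new_sd = parser.get_state_dict()        # consistency checks run here

    # Resume from the edited checkpoint with the larger worker pool.
    dl = StatefulDataLoader(my_dataset, num_workers=4, collate_fn=identity)
    dl.load_state_dict(new_sd)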