1 | 1 | # Copyright (c) Facebook, Inc. and its affiliates. |
| 2 | +import contextlib |
2 | 3 | import copy |
3 | 4 | import itertools |
4 | 5 | import logging |
5 | 6 | import numpy as np |
6 | 7 | import pickle |
7 | 8 | import random |
| 9 | +from typing import Callable, Union |
8 | 10 | import torch.utils.data as data |
9 | 11 | from torch.utils.data.sampler import Sampler |
10 | 12 |
11 | 13 | from detectron2.utils.serialize import PicklableWrapper |
12 | 14 |
13 | 15 | __all__ = ["MapDataset", "DatasetFromList", "AspectRatioGroupedDataset", "ToIterableDataset"] |
14 | 16 |
| 17 | +logger = logging.getLogger(__name__) |
| 18 | + |
15 | 19 |
16 | 20 | def _shard_iterator_dataloader_worker(iterable): |
17 | 21 | # Shard the iterable if we're currently inside pytorch dataloader worker. |
@@ -106,56 +110,101 @@ def __getitem__(self, idx): |
106 | 110 | ) |
107 | 111 |
108 | 112 |
| 113 | +class NumpySerializedList(object): |
| 114 | + """ |
| 115 | + A list-like object whose items are serialized and stored in a Numpy Array. When |
| 116 | + forking a process that has NumpySerializedList, subprocesses can read the same list |
| 117 | + without triggering copy-on-access, therefore they will share RAM for the list. This |
| 118 | + avoids the issue in https://github.com/pytorch/pytorch/issues/13246 |
| 119 | + """ |
| 120 | + |
| 121 | + def __init__(self, lst: list): |
| 122 | + self._lst = lst |
| 123 | + |
| 124 | + def _serialize(data): |
| 125 | + buffer = pickle.dumps(data, protocol=-1) |
| 126 | + return np.frombuffer(buffer, dtype=np.uint8) |
| 127 | + |
| 128 | + logger.info( |
| 129 | + "Serializing {} elements to byte tensors and concatenating them all ...".format( |
| 130 | + len(self._lst) |
| 131 | + ) |
| 132 | + ) |
| 133 | + self._lst = [_serialize(x) for x in self._lst] |
| 134 | + self._addr = np.asarray([len(x) for x in self._lst], dtype=np.int64) |
| 135 | + self._addr = np.cumsum(self._addr) |
| 136 | + self._lst = np.concatenate(self._lst) |
| 137 | + logger.info("Serialized dataset takes {:.2f} MiB".format(len(self._lst) / 1024**2)) |
| 138 | + |
| 139 | + def __len__(self): |
| 140 | + return len(self._addr) |
| 141 | + |
| 142 | + def __getitem__(self, idx): |
| 143 | + start_addr = 0 if idx == 0 else self._addr[idx - 1].item() |
| 144 | + end_addr = self._addr[idx].item() |
| 145 | + bytes = memoryview(self._lst[start_addr:end_addr]) |
| 146 | + |
| 147 | + # @lint-ignore PYTHONPICKLEISBAD |
| 148 | + return pickle.loads(bytes) |
| 149 | + |
| 150 | + |
| 151 | +_DEFAULT_DATASET_FROM_LIST_SERIALIZE_METHOD = NumpySerializedList |
| 152 | + |
| 153 | + |
| 154 | +@contextlib.contextmanager |
| 155 | +def set_default_dataset_from_list_serialize_method(new): |
| 156 | + """ |
| 157 | + Context manager for using a custom serialize function when creating DatasetFromList |
| 158 | + """ |
| 159 | + |
| 160 | + global _DEFAULT_DATASET_FROM_LIST_SERIALIZE_METHOD |
| 161 | + orig = _DEFAULT_DATASET_FROM_LIST_SERIALIZE_METHOD |
| 162 | + _DEFAULT_DATASET_FROM_LIST_SERIALIZE_METHOD = new |
| 163 | + yield |
| 164 | + _DEFAULT_DATASET_FROM_LIST_SERIALIZE_METHOD = orig |
| 165 | + |
| 166 | + |
109 | 167 | class DatasetFromList(data.Dataset): |
110 | 168 | """ |
111 | 169 | Wrap a list to a torch Dataset. It produces elements of the list as data. |
112 | 170 | """ |
113 | 171 |
114 | | - def __init__(self, lst: list, copy: bool = True, serialize: bool = True): |
| 172 | + def __init__( |
| 173 | + self, |
| 174 | + lst: list, |
| 175 | + copy: bool = True, |
| 176 | + serialize: Union[bool, Callable] = True, |
| 177 | + ): |
115 | 178 | """ |
116 | 179 | Args: |
117 | 180 | lst (list): a list which contains elements to produce. |
118 | 181 | copy (bool): whether to deepcopy the element when producing it, |
119 | 182 | so that the result can be modified in place without affecting the |
120 | 183 | source in the list. |
121 | | - serialize (bool): whether to hold memory using serialized objects, when |
122 | | - enabled, data loader workers can use shared RAM from master |
123 | | - process instead of making a copy. |
| 184 | + serialize (bool or callable): whether to serialize the storage to another |
| 185 | + backend. If `True`, the default serialize method will be used; if a |
| 186 | + callable is given, it will be used as the serialize method. |
124 | 187 | """ |
125 | 188 | self._lst = lst |
126 | 189 | self._copy = copy |
127 | | - self._serialize = serialize |
128 | | - |
129 | | - def _serialize(data): |
130 | | - buffer = pickle.dumps(data, protocol=-1) |
131 | | - return np.frombuffer(buffer, dtype=np.uint8) |
| 190 | + if not isinstance(serialize, (bool, Callable)): |
| 191 | + raise TypeError(f"Unsupported type for argument `serialize`: {serialize}") |
| 192 | + self._serialize = serialize is not False |
132 | 193 |
133 | 194 | if self._serialize: |
134 | | - logger = logging.getLogger(__name__) |
135 | | - logger.info( |
136 | | - "Serializing {} elements to byte tensors and concatenating them all ...".format( |
137 | | - len(self._lst) |
138 | | - ) |
| 195 | + serialize_method = ( |
| 196 | + serialize |
| 197 | + if isinstance(serialize, Callable) |
| 198 | + else _DEFAULT_DATASET_FROM_LIST_SERIALIZE_METHOD |
139 | 199 | ) |
140 | | - self._lst = [_serialize(x) for x in self._lst] |
141 | | - self._addr = np.asarray([len(x) for x in self._lst], dtype=np.int64) |
142 | | - self._addr = np.cumsum(self._addr) |
143 | | - self._lst = np.concatenate(self._lst) |
144 | | - logger.info("Serialized dataset takes {:.2f} MiB".format(len(self._lst) / 1024**2)) |
| 200 | + logger.info(f"Serializing the dataset using: {serialize_method}") |
| 201 | + self._lst = serialize_method(self._lst) |
145 | 202 |
146 | 203 | def __len__(self): |
147 | | - if self._serialize: |
148 | | - return len(self._addr) |
149 | | - else: |
150 | | - return len(self._lst) |
| 204 | + return len(self._lst) |
151 | 205 |
152 | 206 | def __getitem__(self, idx): |
153 | | - if self._serialize: |
154 | | - start_addr = 0 if idx == 0 else self._addr[idx - 1].item() |
155 | | - end_addr = self._addr[idx].item() |
156 | | - bytes = memoryview(self._lst[start_addr:end_addr]) |
157 | | - return pickle.loads(bytes) |
158 | | - elif self._copy: |
| 207 | + if self._copy and not self._serialize: |
159 | 208 | return copy.deepcopy(self._lst[idx]) |
160 | 209 | else: |
161 | 210 | return self._lst[idx] |
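For context, a minimal usage sketch of the refactored `DatasetFromList` (not part of the commit; the dataset dicts below are made up). With `serialize=True`, the list is pickled once into a single uint8 numpy buffer, so dataloader workers forked from the main process share the memory pages instead of each copying the Python list, which is the copy-on-access problem tracked in pytorch issue 13246:

import torch.utils.data as data
from detectron2.data.common import DatasetFromList

# Hypothetical dataset dicts, for illustration only.
dicts = [{"file_name": f"img_{i}.jpg", "image_id": i} for i in range(1000)]

# Default behavior: items are serialized into one shared numpy buffer.
dataset = DatasetFromList(dicts, copy=False, serialize=True)

item = dataset[42]  # deserialized on access; each access returns a fresh object
loader = data.DataLoader(dataset, batch_size=8, num_workers=4)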
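And a hedged sketch of the two new extension points: passing a callable as `serialize`, and temporarily swapping the default via the new context manager. `ListBackedStorage` is a hypothetical stand-in; any callable that maps a list to an object supporting `__len__` and `__getitem__` works:

from detectron2.data.common import (
    DatasetFromList,
    set_default_dataset_from_list_serialize_method,
)


class ListBackedStorage:
    # Hypothetical serialize method: keeps the list as-is (no shared-RAM
    # benefit), purely to illustrate the expected interface.
    def __init__(self, lst: list):
        self._lst = list(lst)

    def __len__(self):
        return len(self._lst)

    def __getitem__(self, idx):
        return self._lst[idx]


# Pass the callable directly for a single dataset ...
ds = DatasetFromList([{"x": 1}, {"x": 2}], serialize=ListBackedStorage)

# ... or swap the default for every DatasetFromList created in a scope.
with set_default_dataset_from_list_serialize_method(ListBackedStorage):
    ds2 = DatasetFromList([{"x": 1}, {"x": 2}])
# On normal exit the previous default (NumpySerializedList) is restored.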