|
13 | 13 |
|
14 | 14 | import json
15 | 15 | import os
| 16 | +import warnings
16 | 17 | from dataclasses import dataclass
17 | 18 | from time import sleep
18 | 19 | from typing import Any, Dict, List, Optional, Tuple, Union
|
24 | 25 | from lightning.data.streaming.compression import _COMPRESSORS, Compressor
25 | 26 | from lightning.data.streaming.constants import _INDEX_FILENAME, _TORCH_GREATER_EQUAL_2_1_0
26 | 27 | from lightning.data.streaming.serializers import _SERIALIZERS, Serializer
| 28 | +from lightning.data.utilities.format import _human_readable_bytes
27 | 29 |
|
28 | 30 | if _TORCH_GREATER_EQUAL_2_1_0:
29 | 31 |     from torch.utils._pytree import PyTree, tree_flatten, treespec_dumps
|
42 | 44 |
|
43 | 45 |
|
44 | 46 | def _convert_bytes_to_int(bytes_str: str) -> int:
45 | | -    """Convert human readable byte format to an integer."""
| 47 | +    """Convert human-readable byte format to an integer."""
46 | 48 |     for suffix in _FORMAT_TO_RATIO:
47 | 49 |         bytes_str = bytes_str.lower().strip()
48 | 50 |         if bytes_str.lower().endswith(suffix):
49 | 51 |             try:
50 | 52 |                 return int(float(bytes_str[0 : -len(suffix)]) * _FORMAT_TO_RATIO[suffix])
51 | 53 |             except ValueError:
52 | 54 |                 raise ValueError(
53 | | -                    "".join(
54 | | -                        [
55 | | -                            f"Unsupported value/suffix {bytes_str}. Supported suffix are ",
56 | | -                            f'{["b"] + list(_FORMAT_TO_RATIO.keys())}.',
57 | | -                        ]
58 | | -                    )
| 55 | +                    f"Unsupported value/suffix {bytes_str}. Supported suffixes are "
| 56 | +                    f'{["b"] + list(_FORMAT_TO_RATIO.keys())}.'
59 | 57 |                 )
60 | 58 |     raise ValueError(f"The supported units are {_FORMAT_TO_RATIO.keys()}")
61 | 59 |
|
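For context, `_convert_bytes_to_int` lowercases and strips the input, then multiplies the numeric prefix by the ratio registered for its suffix. A minimal usage sketch, assuming `_FORMAT_TO_RATIO` maps "kb"/"mb"/"gb" to powers of 1024 (the exact table is defined earlier in this module):

    # Assumes _FORMAT_TO_RATIO = {"kb": 1024, "mb": 1024**2, ...} as defined above.
    assert _convert_bytes_to_int("64mb") == 64 * 1024**2
    assert _convert_bytes_to_int(" 2GB ") == 2 * 1024**3  # case- and whitespace-insensitive
    _convert_bytes_to_int("64xb")  # raises ValueError: the suffix is not registered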
@@ -212,39 +210,42 @@ def _serialize(self, item: Any, sizes: List[int], data: List[bytes]) -> str: |
212 | 210 |
|
213 | 211 |     def _create_chunk(self, filename: str, on_done: bool = False) -> bytes:
214 | 212 |         """Create a binary chunk from all the binarized items."""
| 213 | +        items = []
| 214 | +
215 | 215 |         if on_done:
216 | 216 |             indices = sorted(self._serialized_items.keys())
217 | 217 |             for i in range(len(indices) - 1):
218 | 218 |                 assert indices[i] == indices[i + 1] - 1, indices
219 | | -            min_index = indices[0]
220 | | -            max_index = indices[-1] + 1
221 | | -            num_items = np.uint32(max_index - min_index)
222 | 219 |             items = [self._serialized_items.pop(index) for index in indices]
223 | 220 |         else:
224 | 221 |             assert self._max_index is not None, (self._max_index, self._min_index)
225 | 222 |             assert self._min_index is not None, (self._max_index, self._min_index)
226 | | -            num_items = np.uint32(self._max_index - self._min_index)
227 | | -            items = [self._serialized_items.pop(index) for index in range(self._min_index, self._max_index)]
228 | | -            min_index = self._min_index
229 | | -            max_index = self._max_index
| 223 | +            if self._max_index == self._min_index:
| 224 | +                # A single item is larger than the target chunk size; allow the chunk to be bigger than the target size
| 225 | +                items.append(self._serialized_items.pop(self._max_index))
| 226 | +            items.extend(self._serialized_items.pop(index) for index in range(self._min_index, self._max_index))
230 | 227 |
|
231 | 228 |         if len(items) == 0:
232 | 229 |             raise RuntimeError(
233 | 230 |                 "The items shouldn't have an empty length. Something went wrong."
234 | 231 |                 f" Found {self._pretty_serialized_items()} with boundaries: {self._min_index}, {self._max_index}."
235 | 232 |             )
236 | 233 |
|
| 234 | +        num_items = np.uint32(len(items))
237 | 235 |         sizes = list(map(len, items))
238 | 236 |         offsets = np.array([0] + sizes).cumsum().astype(np.uint32)
239 | 237 |         offsets += len(num_items.tobytes()) + len(offsets.tobytes())
240 | 238 |         sample_data = b"".join([item.data for item in items])
241 | 239 |         data = num_items.tobytes() + offsets.tobytes() + sample_data
242 | | -        offsets = offsets.tolist()
243 | 240 |
|
244 | 241 |         current_chunk_bytes = sum([item.bytes for item in items])
245 | 242 |
|
246 | | -        if self._chunk_bytes:
247 | | -            assert current_chunk_bytes <= self._chunk_bytes
| 243 | +        if self._chunk_bytes and current_chunk_bytes > self._chunk_bytes:
| 244 | +            warnings.warn(
| 245 | +                f"An item was larger than the target chunk size ({_human_readable_bytes(self._chunk_bytes)})."
| 246 | +                f" The current chunk will be {_human_readable_bytes(current_chunk_bytes)} in size.",
| 247 | +                UserWarning,
| 248 | +            )
248 | 249 |
|
249 | 250 |         if self._chunk_size:
250 | 251 |             assert num_items.item() <= self._chunk_size
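To make the chunk layout concrete: each chunk starts with `num_items` as a uint32, followed by a uint32 offset table shifted by the header length so that every offset is an absolute position within the chunk, then the concatenated sample bytes. A small self-contained sketch, with raw byte strings standing in for the real `item.data` payloads:

    import numpy as np

    samples = [b"abc", b"defgh"]  # illustrative payloads
    num_items = np.uint32(len(samples))
    offsets = np.array([0] + [len(s) for s in samples]).cumsum().astype(np.uint32)
    offsets += len(num_items.tobytes()) + len(offsets.tobytes())  # make offsets absolute
    chunk = num_items.tobytes() + offsets.tobytes() + b"".join(samples)

    # A reader recovers item i by slicing between consecutive offsets.
    assert chunk[offsets[0] : offsets[1]] == b"abc"

Note that `num_items` is now derived from `len(items)` directly, which stays correct in the new single-oversized-item path where `self._max_index == self._min_index`.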
@@ -308,6 +309,7 @@ def add_item(self, index: int, items: Any) -> Optional[str]: |
308 | 309 |         return filepath
309 | 310 |
|
310 | 311 |     def _should_write(self) -> bool:
| 312 | +        # TODO: Misleading method name, it modifies `self._min_index` and `self._max_index`!
311 | 313 |         if not self._serialized_items:
312 | 314 |             return False
313 | 315 |         indexes = list(self._serialized_items.keys())
|
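Finally, a hedged sketch of the behavior the relaxed size check enables: writing a single item larger than the `chunk_bytes` target now emits a `UserWarning` and produces an oversized chunk instead of failing an assertion. The constructor arguments and the `done()` call below are assumptions made for illustration, not a verified public API:

    import tempfile
    import warnings

    from lightning.data.streaming.writer import BinaryWriter

    # Hypothetical setup: a 64 kb target chunk size (kwargs assumed, not verified).
    writer = BinaryWriter(tempfile.mkdtemp(), chunk_bytes="64kb")
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        writer.add_item(0, b"x" * (1 << 20))  # a single 1 MB item exceeds the target
        writer.done()  # flushes the oversized chunk rather than asserting
    for w in caught:
        print(w.message)  # the "larger than the target chunk size" warning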