|
13 | 13 |
|
14 | 14 | import json
|
15 | 15 | import os
|
| 16 | +import warnings |
16 | 17 | from dataclasses import dataclass
|
17 | 18 | from time import sleep
|
18 | 19 | from typing import Any, Dict, List, Optional, Tuple, Union
|
|
24 | 25 | from lightning.data.streaming.compression import _COMPRESSORS, Compressor
|
25 | 26 | from lightning.data.streaming.constants import _INDEX_FILENAME, _TORCH_GREATER_EQUAL_2_1_0
|
26 | 27 | from lightning.data.streaming.serializers import _SERIALIZERS, Serializer
|
| 28 | +from lightning.data.utilities.format import _human_readable_bytes |
27 | 29 |
|
28 | 30 | if _TORCH_GREATER_EQUAL_2_1_0:
|
29 | 31 | from torch.utils._pytree import PyTree, tree_flatten, treespec_dumps
|
|
42 | 44 |
|
43 | 45 |
|
def _convert_bytes_to_int(bytes_str: str) -> int:
    """Convert a human-readable byte string (e.g. ``"64MB"``) to an integer number of bytes.

    Args:
        bytes_str: A value such as ``"100kb"`` or ``"1.5gb"``. Matching is case-insensitive
            and surrounding whitespace is ignored.

    Returns:
        The number of bytes as an ``int`` (the fractional part, if any, is truncated).

    Raises:
        ValueError: If the numeric portion cannot be parsed, or if no known unit suffix
            matches the input.

    """
    # Normalize once up front; re-normalizing inside the loop is idempotent but wasteful,
    # and the original also called ``.lower()`` a second time in the ``endswith`` check.
    bytes_str = bytes_str.lower().strip()
    for suffix in _FORMAT_TO_RATIO:
        # NOTE(review): matching depends on the key order of ``_FORMAT_TO_RATIO`` — a bare
        # "b" key listed before "kb"/"mb" would shadow the longer suffixes. Assumed the
        # keys are ordered longest-first; confirm where ``_FORMAT_TO_RATIO`` is defined.
        if bytes_str.endswith(suffix):
            try:
                return int(float(bytes_str[0 : -len(suffix)]) * _FORMAT_TO_RATIO[suffix])
            except ValueError:
                # Re-raise with a caller-friendly message; the numeric prefix was invalid
                # (e.g. "12xkb"). Message text kept byte-identical to the original.
                raise ValueError(
                    f"Unsupported value/suffix {bytes_str}. Supported suffix are "
                    f'{["b"] + list(_FORMAT_TO_RATIO.keys())}.'
                )
    # No suffix matched at all.
    raise ValueError(f"The supported units are {_FORMAT_TO_RATIO.keys()}")
|
61 | 59 |
|
@@ -212,39 +210,42 @@ def _serialize(self, item: Any, sizes: List[int], data: List[bytes]) -> str:
|
212 | 210 |
|
213 | 211 | def _create_chunk(self, filename: str, on_done: bool = False) -> bytes:
|
214 | 212 | """Create a binary chunk from all the binarized items."""
|
| 213 | + items = [] |
| 214 | + |
215 | 215 | if on_done:
|
216 | 216 | indices = sorted(self._serialized_items.keys())
|
217 | 217 | for i in range(len(indices) - 1):
|
218 | 218 | assert indices[i] == indices[i + 1] - 1, indices
|
219 |
| - min_index = indices[0] |
220 |
| - max_index = indices[-1] + 1 |
221 |
| - num_items = np.uint32(max_index - min_index) |
222 | 219 | items = [self._serialized_items.pop(index) for index in indices]
|
223 | 220 | else:
|
224 | 221 | assert self._max_index is not None, (self._max_index, self._min_index)
|
225 | 222 | assert self._min_index is not None, (self._max_index, self._min_index)
|
226 |
| - num_items = np.uint32(self._max_index - self._min_index) |
227 |
| - items = [self._serialized_items.pop(index) for index in range(self._min_index, self._max_index)] |
228 |
| - min_index = self._min_index |
229 |
| - max_index = self._max_index |
| 223 | + if self._max_index == self._min_index: |
| 224 | + # A single item is larger than the target chunk size; allow the chunk to be bigger than the target size |
| 225 | + items.append(self._serialized_items.pop(self._max_index)) |
| 226 | + items.extend(self._serialized_items.pop(index) for index in range(self._min_index, self._max_index)) |
230 | 227 |
|
231 | 228 | if len(items) == 0:
|
232 | 229 | raise RuntimeError(
|
233 | 230 | "The items shouldn't have an empty length. Something went wrong."
|
234 | 231 | f" Found {self._pretty_serialized_items()} with boundaries: {self._min_index}, {self._max_index}."
|
235 | 232 | )
|
236 | 233 |
|
| 234 | + num_items = np.uint32(len(items)) |
237 | 235 | sizes = list(map(len, items))
|
238 | 236 | offsets = np.array([0] + sizes).cumsum().astype(np.uint32)
|
239 | 237 | offsets += len(num_items.tobytes()) + len(offsets.tobytes())
|
240 | 238 | sample_data = b"".join([item.data for item in items])
|
241 | 239 | data = num_items.tobytes() + offsets.tobytes() + sample_data
|
242 |
| - offsets = offsets.tolist() |
243 | 240 |
|
244 | 241 | current_chunk_bytes = sum([item.bytes for item in items])
|
245 | 242 |
|
246 |
| - if self._chunk_bytes: |
247 |
| - assert current_chunk_bytes <= self._chunk_bytes |
| 243 | + if self._chunk_bytes and current_chunk_bytes > self._chunk_bytes: |
| 244 | + warnings.warn( |
| 245 | + f"An item was larger than the target chunk size ({_human_readable_bytes(self._chunk_bytes)})." |
| 246 | + f" The current chunk will be {_human_readable_bytes(current_chunk_bytes)} in size.", |
| 247 | + UserWarning, |
| 248 | + ) |
248 | 249 |
|
249 | 250 | if self._chunk_size:
|
250 | 251 | assert num_items.item() <= self._chunk_size
|
@@ -308,6 +309,7 @@ def add_item(self, index: int, items: Any) -> Optional[str]:
|
308 | 309 | return filepath
|
309 | 310 |
|
310 | 311 | def _should_write(self) -> bool:
|
| 312 | + # TODO: Misleading method name, it modifies `self._min_index` and `self._max_index`! |
311 | 313 | if not self._serialized_items:
|
312 | 314 | return False
|
313 | 315 | indexes = list(self._serialized_items.keys())
|
|
0 commit comments