
Commit ab76989

awaelchli and pre-commit-ci[bot] authored and committed
Fix oversized items not fitting into a chunk (#18938)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
(cherry picked from commit 0e7a3b0)
1 parent 201fb4d commit ab76989
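
In short: before this change, the streaming writer asserted that a finished chunk never exceeded the `chunk_bytes` target, so a single item larger than the target crashed the writer; after this change the item is written into its own oversized chunk and a `UserWarning` is emitted. A minimal sketch of the new behavior (the `Cache` import path and constructor arguments are assumed from the tests added in this commit):

# Minimal sketch of the fixed behavior; import path assumed from this commit's tests.
import tempfile

import numpy as np

from lightning.data.streaming.cache import Cache

cache = Cache(tempfile.mkdtemp(), chunk_bytes=700)  # target chunk size: 700 bytes

# A single ~10 kB item cannot fit into a 700-byte chunk. Before this commit the writer's
# `assert current_chunk_bytes <= self._chunk_bytes` failed here; now the item is written
# into an oversized chunk and a UserWarning ("An item was larger than the target chunk
# size ...") is emitted instead.
cache[0] = np.random.randint(0, 10, size=(10000,), dtype=np.uint8)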

File tree

5 files changed: +77 -18 lines changed

src/lightning/data/streaming/writer.py

Lines changed: 19 additions & 17 deletions
@@ -13,6 +13,7 @@
 
 import json
 import os
+import warnings
 from dataclasses import dataclass
 from time import sleep
 from typing import Any, Dict, List, Optional, Tuple, Union
@@ -24,6 +25,7 @@
 from lightning.data.streaming.compression import _COMPRESSORS, Compressor
 from lightning.data.streaming.constants import _INDEX_FILENAME, _TORCH_GREATER_EQUAL_2_1_0
 from lightning.data.streaming.serializers import _SERIALIZERS, Serializer
+from lightning.data.utilities.format import _human_readable_bytes
 
 if _TORCH_GREATER_EQUAL_2_1_0:
     from torch.utils._pytree import PyTree, tree_flatten, treespec_dumps
@@ -42,20 +44,16 @@
 
 
 def _convert_bytes_to_int(bytes_str: str) -> int:
-    """Convert human readable byte format to an integer."""
+    """Convert human-readable byte format to an integer."""
     for suffix in _FORMAT_TO_RATIO:
         bytes_str = bytes_str.lower().strip()
         if bytes_str.lower().endswith(suffix):
             try:
                 return int(float(bytes_str[0 : -len(suffix)]) * _FORMAT_TO_RATIO[suffix])
             except ValueError:
                 raise ValueError(
-                    "".join(
-                        [
-                            f"Unsupported value/suffix {bytes_str}. Supported suffix are ",
-                            f'{["b"] + list(_FORMAT_TO_RATIO.keys())}.',
-                        ]
-                    )
+                    f"Unsupported value/suffix {bytes_str}. Supported suffix are "
+                    f'{["b"] + list(_FORMAT_TO_RATIO.keys())}.'
                 )
     raise ValueError(f"The supported units are {_FORMAT_TO_RATIO.keys()}")
@@ -212,39 +210,42 @@ def _serialize(self, item: Any, sizes: List[int], data: List[bytes]) -> str:
 
     def _create_chunk(self, filename: str, on_done: bool = False) -> bytes:
         """Create a binary chunk from all the binarized items."""
+        items = []
+
         if on_done:
             indices = sorted(self._serialized_items.keys())
             for i in range(len(indices) - 1):
                 assert indices[i] == indices[i + 1] - 1, indices
-            min_index = indices[0]
-            max_index = indices[-1] + 1
-            num_items = np.uint32(max_index - min_index)
             items = [self._serialized_items.pop(index) for index in indices]
         else:
             assert self._max_index is not None, (self._max_index, self._min_index)
             assert self._min_index is not None, (self._max_index, self._min_index)
-            num_items = np.uint32(self._max_index - self._min_index)
-            items = [self._serialized_items.pop(index) for index in range(self._min_index, self._max_index)]
-            min_index = self._min_index
-            max_index = self._max_index
+            if self._max_index == self._min_index:
+                # A single item is larger than the target chunk size; allow the chunk to be bigger than the target size
+                items.append(self._serialized_items.pop(self._max_index))
+            items.extend(self._serialized_items.pop(index) for index in range(self._min_index, self._max_index))
 
         if len(items) == 0:
             raise RuntimeError(
                 "The items shouldn't have an empty length. Something went wrong."
                 f" Found {self._pretty_serialized_items()} with boundaries: {self._min_index}, {self._max_index}."
             )
 
+        num_items = np.uint32(len(items))
         sizes = list(map(len, items))
         offsets = np.array([0] + sizes).cumsum().astype(np.uint32)
         offsets += len(num_items.tobytes()) + len(offsets.tobytes())
         sample_data = b"".join([item.data for item in items])
         data = num_items.tobytes() + offsets.tobytes() + sample_data
-        offsets = offsets.tolist()
 
         current_chunk_bytes = sum([item.bytes for item in items])
 
-        if self._chunk_bytes:
-            assert current_chunk_bytes <= self._chunk_bytes
+        if self._chunk_bytes and current_chunk_bytes > self._chunk_bytes:
+            warnings.warn(
+                f"An item was larger than the target chunk size ({_human_readable_bytes(self._chunk_bytes)})."
+                f" The current chunk will be {_human_readable_bytes(current_chunk_bytes)} in size.",
+                UserWarning,
+            )
 
         if self._chunk_size:
             assert num_items.item() <= self._chunk_size
@@ -308,6 +309,7 @@ def add_item(self, index: int, items: Any) -> Optional[str]:
         return filepath
 
     def _should_write(self) -> bool:
+        # TODO: Misleading method name, it modifies `self._min_index` and `self._max_index`!
         if not self._serialized_items:
            return False
         indexes = list(self._serialized_items.keys())
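
For reference, the chunk header built in `_create_chunk` above stores the item count followed by the cumulative byte offsets of the items, all as uint32. A small worked example derived from that code (not part of the commit itself), for two serialized items of 10 and 20 bytes:

import numpy as np

sizes = [10, 20]                                             # byte sizes of the two items
num_items = np.uint32(len(sizes))                            # 4-byte item count
offsets = np.array([0] + sizes).cumsum().astype(np.uint32)   # [0, 10, 30]
# header size = 4 bytes (num_items) + 3 * 4 bytes (offsets) = 16 bytes
offsets += len(num_items.tobytes()) + len(offsets.tobytes())
print(offsets.tolist())  # [16, 26, 46]: start offset of each item; the last entry marks the end of the data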

src/lightning/data/utilities/__init__.py

Whitespace-only changes.
src/lightning/data/utilities/format.py

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+def _human_readable_bytes(num_bytes: float) -> str:
+    for unit in ("B", "KB", "MB", "GB", "TB"):
+        if abs(num_bytes) < 1000.0:
+            return f"{num_bytes:3.1f} {unit}"
+        num_bytes /= 1000.0
+    return f"{num_bytes:.1f} PB"

tests/tests_data/streaming/test_cache.py

Lines changed: 34 additions & 1 deletion
@@ -10,7 +10,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import json
 import os
 import sys
 from functools import partial
@@ -27,6 +27,7 @@
 from lightning.fabric import Fabric
 from lightning.pytorch.demos.boring_classes import RandomDataset
 from lightning_utilities.core.imports import RequirementCache
+from lightning_utilities.test.warning import no_warning_call
 from torch.utils.data import DataLoader, Dataset
 
 _PIL_AVAILABLE = RequirementCache("PIL")
@@ -242,3 +243,35 @@ def test_streaming_dataset(tmpdir, monkeypatch):
 
     dataloader = DataLoader(dataset, num_workers=2, batch_size=2)
     assert len(dataloader) == 408
+
+
+def test_create_oversized_chunk_single_item(tmp_path):
+    cache = Cache(str(tmp_path), chunk_bytes=700)
+    with pytest.warns(UserWarning, match="An item was larger than the target chunk size"):
+        cache[0] = np.random.randint(0, 10, size=(10000,), dtype=np.uint8)
+
+
+def test_create_undersized_and_oversized_chunk(tmp_path):
+    cache = Cache(str(tmp_path), chunk_bytes=9000)  # target: 9KB chunks
+    with no_warning_call(UserWarning):
+        cache[0] = np.random.randint(0, 10, size=(500,), dtype=np.uint8)  # will result in undersized chunk
+        cache[1] = np.random.randint(0, 10, size=(10000,), dtype=np.uint8)  # will result in oversized chunk
+    with pytest.warns(UserWarning, match="An item was larger than the target chunk size"):
+        cache[2] = np.random.randint(0, 10, size=(150,), dtype=np.uint8)
+    with no_warning_call(UserWarning):
+        cache[3] = np.random.randint(0, 10, size=(200,), dtype=np.uint8)
+
+    cache.done()
+    cache.merge()
+
+    assert len(os.listdir(tmp_path)) == 4  # 3 chunks + 1 index file
+    with open(tmp_path / "index.json") as file:
+        index = json.load(file)
+
+    chunks = index["chunks"]
+    assert chunks[0]["chunk_size"] == 1
+    assert chunks[0]["filename"] == "chunk-0-0.bin"
+    assert chunks[1]["chunk_size"] == 1
+    assert chunks[1]["filename"] == "chunk-0-1.bin"
+    assert chunks[2]["chunk_size"] == 2
+    assert chunks[2]["filename"] == "chunk-0-2.bin"
tests/tests_data/utilities/test_format.py

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+from lightning.data.utilities.format import _human_readable_bytes
+
+
+def test_human_readable_bytes():
+    assert _human_readable_bytes(0) == "0.0 B"
+    assert _human_readable_bytes(1) == "1.0 B"
+    assert _human_readable_bytes(999) == "999.0 B"
+    assert _human_readable_bytes(int(1e3)) == "1.0 KB"
+    assert _human_readable_bytes(int(1e3 + 1e2)) == "1.1 KB"
+    assert _human_readable_bytes(int(1e6)) == "1.0 MB"
+    assert _human_readable_bytes(int(1e6 + 2e5)) == "1.2 MB"
+    assert _human_readable_bytes(int(1e9)) == "1.0 GB"
+    assert _human_readable_bytes(int(1e9 + 3e8)) == "1.3 GB"
+    assert _human_readable_bytes(int(1e12)) == "1.0 TB"
+    assert _human_readable_bytes(int(1e12 + 4e11)) == "1.4 TB"
+    assert _human_readable_bytes(int(1e15)) == "1.0 PB"
+    assert _human_readable_bytes(int(1e15 + 5e14)) == "1.5 PB"
+    assert _human_readable_bytes(int(1e18)) == "1000.0 PB"
