Skip to content

Commit ebc3bcd

Browse files
tchatonthomas
authored andcommitted
Add human readable format for chunk_bytes (#18925)
Co-authored-by: thomas <[email protected]> (cherry picked from commit 37cbee4)
1 parent e211a93 commit ebc3bcd

File tree

6 files changed

+50
-13
lines changed

6 files changed

+50
-13
lines changed

src/lightning/data/streaming/cache.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def __init__(
4141
version: Optional[Union[int, Literal["latest"]]] = "latest",
4242
compression: Optional[str] = None,
4343
chunk_size: Optional[int] = None,
44-
chunk_bytes: Optional[int] = None,
44+
chunk_bytes: Optional[Union[int, str]] = None,
4545
item_loader: Optional[BaseItemLoader] = None,
4646
):
4747
"""The Cache enables to optimise dataset format for cloud training. This is done by grouping several elements

src/lightning/data/streaming/data_processor.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -561,7 +561,10 @@ def _done(self, delete_cached_files: bool, remote_output_dir: Any) -> None:
561561

562562
class DataChunkRecipe(DataRecipe):
563563
def __init__(
564-
self, chunk_size: Optional[int] = None, chunk_bytes: Optional[int] = None, compression: Optional[str] = None
564+
self,
565+
chunk_size: Optional[int] = None,
566+
chunk_bytes: Optional[Union[int, str]] = None,
567+
compression: Optional[str] = None,
565568
):
566569
super().__init__()
567570
if chunk_size is not None and chunk_bytes is not None:

src/lightning/data/streaming/functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def __init__(
6666
fn: Callable[[Any], None],
6767
inputs: Sequence[Any],
6868
chunk_size: Optional[int],
69-
chunk_bytes: Optional[int],
69+
chunk_bytes: Optional[Union[int, str]],
7070
compression: Optional[str],
7171
):
7272
super().__init__(chunk_size=chunk_size, chunk_bytes=chunk_bytes, compression=compression)
@@ -141,7 +141,7 @@ def optimize(
141141
inputs: Sequence[Any],
142142
output_dir: str,
143143
chunk_size: Optional[int] = None,
144-
chunk_bytes: Optional[int] = None,
144+
chunk_bytes: Optional[Union[int, str]] = None,
145145
compression: Optional[str] = None,
146146
name: Optional[str] = None,
147147
num_workers: Optional[int] = None,

src/lightning/data/streaming/writer.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import os
1616
from dataclasses import dataclass
1717
from time import sleep
18-
from typing import Any, Dict, List, Optional, Tuple
18+
from typing import Any, Dict, List, Optional, Tuple, Union
1919

2020
import numpy as np
2121
import torch
@@ -29,11 +29,35 @@
2929
from torch.utils._pytree import PyTree, tree_flatten, treespec_dumps
3030

3131

32-
def _get_data_optimizer_node_rank() -> Optional[int]:
33-
node_rank = os.getenv("DATA_OPTIMIZER_NODE_RANK", None)
34-
if node_rank is not None:
35-
return int(node_rank)
36-
return node_rank
32+
_FORMAT_TO_RATIO = {
33+
"kb": 1024,
34+
"mb": 1024**2,
35+
"gb": 1024**3,
36+
"tb": 1024**4,
37+
"pb": 1024**5,
38+
"eb": 1024**6,
39+
"zb": 1024**7,
40+
"yb": 1024**8,
41+
}
42+
43+
44+
def _convert_bytes_to_int(bytes_str: str) -> int:
45+
"""Convert human readable byte format to an integer."""
46+
for suffix in _FORMAT_TO_RATIO:
47+
bytes_str = bytes_str.lower().strip()
48+
if bytes_str.lower().endswith(suffix):
49+
try:
50+
return int(float(bytes_str[0 : -len(suffix)]) * _FORMAT_TO_RATIO[suffix])
51+
except ValueError:
52+
raise ValueError(
53+
"".join(
54+
[
55+
f"Unsupported value/suffix {bytes_str}. Supported suffix are ",
56+
f'{["b"] + list(_FORMAT_TO_RATIO.keys())}.',
57+
]
58+
)
59+
)
60+
raise ValueError(f"The supported units are {_FORMAT_TO_RATIO.keys()}")
3761

3862

3963
@dataclass
@@ -52,7 +76,7 @@ def __init__(
5276
self,
5377
cache_dir: str,
5478
chunk_size: Optional[int] = None,
55-
chunk_bytes: Optional[int] = None,
79+
chunk_bytes: Optional[Union[int, str]] = None,
5680
compression: Optional[str] = None,
5781
follow_tensor_dimension: bool = True,
5882
):
@@ -75,7 +99,7 @@ def __init__(
7599

76100
self._serializers: Dict[str, Serializer] = _SERIALIZERS
77101
self._chunk_size = chunk_size
78-
self._chunk_bytes = chunk_bytes
102+
self._chunk_bytes = _convert_bytes_to_int(chunk_bytes) if isinstance(chunk_bytes, str) else chunk_bytes
79103
self._compression = compression
80104

81105
self._data_format: Optional[List[str]] = None

tests/tests_data/datasets/test_iterable.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,7 @@ def sharding_resume_test(fabric: lightning.Fabric, num_workers):
437437
fabric.barrier()
438438

439439

440+
@pytest.mark.skipif(True, reason="flaky and out-dated")
440441
@pytest.mark.parametrize(
441442
("num_workers", "world_size"),
442443
[

tests/tests_data/streaming/test_writer.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from lightning import seed_everything
2020
from lightning.data.streaming.reader import BinaryReader
2121
from lightning.data.streaming.sampler import ChunkedIndex
22-
from lightning.data.streaming.writer import BinaryWriter
22+
from lightning.data.streaming.writer import _FORMAT_TO_RATIO, BinaryWriter
2323
from lightning_utilities.core.imports import RequirementCache
2424

2525
_PIL_AVAILABLE = RequirementCache("PIL")
@@ -194,3 +194,12 @@ def test_binary_writer_with_jpeg_and_png(tmpdir):
194194

195195
with pytest.raises(ValueError, match="The data format changed between items"):
196196
binary_writer[2] = {"x": 2, "y": 1}
197+
198+
199+
def test_writer_human_format(tmpdir):
200+
for k, v in _FORMAT_TO_RATIO.items():
201+
binary_writer = BinaryWriter(tmpdir, chunk_bytes=f"{1}{k}")
202+
assert binary_writer._chunk_bytes == v
203+
204+
binary_writer = BinaryWriter(tmpdir, chunk_bytes="64MB")
205+
assert binary_writer._chunk_bytes == 67108864

0 commit comments

Comments
 (0)