Skip to content

Commit fc59c8a

Browse files
Feat: add multiple transform_fn support in StreamingDataset (Lightning-AI#655)
* feat: add support to pass transform fn kwargs * update readme * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix pre-commit * update * update * update * remove transform-kwargs * update * updated --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 371f262 commit fc59c8a

File tree

3 files changed

+68
-7
lines changed

3 files changed

+68
-7
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -935,7 +935,7 @@ if __name__ == "__main__":
935935

936936
Transform datasets on-the-fly while streaming them, allowing for efficient data processing without the need to store intermediate results.
937937

938-
- You can use the `transform` argument in `StreamingDataset` to apply a transformation function to each sample as it is streamed.
938+
- You can use the `transform` argument in `StreamingDataset` to apply a `transformation function` or `a list of transformation functions` to each sample as it is streamed.
939939

940940
```python
941941
# Define a simple transform function
@@ -953,7 +953,7 @@ def transform_fn(x, *args, **kwargs):
953953
return torch_transform(x) # Apply the transform to the input image
954954

955955
# Create dataset with appropriate configuration
956-
dataset = StreamingDataset(data_dir, cache_dir=str(cache_dir), shuffle=shuffle, transform=transform_fn)
956+
dataset = StreamingDataset(data_dir, cache_dir=str(cache_dir), shuffle=shuffle, transform=[transform_fn])
957957
```
958958

959959
Or, you can create a subclass of `StreamingDataset` and override its `transform` method to apply custom transformations to each sample.

src/litdata/streaming/dataset.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def __init__(
6262
max_pre_download: int = 2,
6363
index_path: Optional[str] = None,
6464
force_override_state_dict: bool = False,
65-
transform: Optional[Callable] = None,
65+
transform: Optional[Union[Callable, list[Callable]]] = None,
6666
) -> None:
6767
"""The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class.
6868
@@ -89,7 +89,7 @@ def __init__(
8989
If `index_path` is a directory, the function will look for `index.json` within it.
9090
If `index_path` is a full file path, it will use that directly.
9191
force_override_state_dict: Boolean flag for allowing local arguments to override a loaded state dict.
92-
transform: Optional transformation function to apply to each item in the dataset.
92+
transform: Optional transformation function or list of functions to apply to each item in the dataset.
9393
"""
9494
_check_version_and_prompt_upgrade(__version__)
9595

@@ -198,8 +198,10 @@ def __init__(
198198
self.session_options = session_options
199199
self.max_pre_download = max_pre_download
200200
if transform is not None:
201-
if not callable(transform):
202-
raise ValueError(f"Transform should be a callable. Found {transform}")
201+
transform = transform if isinstance(transform, list) else [transform]
202+
for t in transform:
203+
if not callable(t):
204+
raise ValueError(f"Transform should be a callable. Found {t}")
203205
self.transform = transform
204206
self._on_demand_bytes = True # true by default, when iterating, turn this off to store the chunks in the cache
205207

@@ -441,7 +443,14 @@ def __getitem__(self, index: Union[ChunkedIndex, int, slice]) -> Any:
441443
{"name": f"getitem_dataset_for_chunk_index_{index.chunk_index}_and_index_{index.index}", "ph": "E"}
442444
)
443445
)
444-
return self.transform(item) if hasattr(self, "transform") else item
446+
if hasattr(self, "transform"):
447+
if isinstance(self.transform, list):
448+
for transform_fn in self.transform:
449+
item = transform_fn(item)
450+
else:
451+
item = self.transform(item)
452+
453+
return item
445454

446455
def __next__(self) -> Any:
447456
# check if we have reached the end of the dataset (i.e., all the chunks have been processed)

tests/streaming/test_dataset.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import random
1818
import shutil
1919
import sys
20+
from functools import partial
2021
from time import sleep
2122
from typing import Any, Optional
2223
from unittest import mock
@@ -1695,6 +1696,57 @@ def transform_fn(x, *args, **kwargs):
16951696
assert item == i * 2, f"Expected {i * 2}, got {item}"
16961697

16971698

1699+
@pytest.mark.parametrize("shuffle", [True, False])
1700+
def test_dataset_multiple_transform(tmpdir, shuffle):
1701+
"""Test if the dataset transform is applied correctly."""
1702+
# Create a simple dataset
1703+
# Create directories for cache and data
1704+
cache_dir = os.path.join(tmpdir, "cache_dir")
1705+
data_dir = os.path.join(tmpdir, "data_dir")
1706+
os.makedirs(cache_dir)
1707+
os.makedirs(data_dir)
1708+
1709+
# Create a dataset with 100 items, 20 items per chunk
1710+
cache = Cache(str(data_dir), chunk_size=20)
1711+
for i in range(100):
1712+
cache[i] = i
1713+
cache.done()
1714+
cache.merge()
1715+
1716+
# Define two simple transform functions
1717+
def transform_fn_1(x):
1718+
"""A simple transform function that doubles the input."""
1719+
return x * 2
1720+
1721+
def transform_fn_2(x, extra_num):
1722+
"""A simple transform function that adds one to the input."""
1723+
return x + extra_num
1724+
1725+
dataset = StreamingDataset(
1726+
data_dir,
1727+
cache_dir=str(cache_dir),
1728+
shuffle=shuffle,
1729+
transform=[transform_fn_1, partial(transform_fn_2, extra_num=100)],
1730+
)
1731+
dataset_length = len(dataset)
1732+
assert dataset_length == 100
1733+
1734+
# ACT
1735+
# Stream through the entire dataset and store the results
1736+
complete_data = []
1737+
for data in dataset:
1738+
assert data is not None
1739+
complete_data.append(data)
1740+
1741+
if shuffle:
1742+
complete_data.sort()
1743+
1744+
# ASSERT
1745+
# Verify that the transform is applied correctly
1746+
for i, item in enumerate(complete_data):
1747+
assert item == i * 2 + 100, f"Expected {i * 2 + 100}, got {item}"
1748+
1749+
16981750
@pytest.mark.parametrize("shuffle", [True, False])
16991751
def test_dataset_transform_inheritance(tmpdir, shuffle):
17001752
"""Test if the dataset transform is applied correctly."""

0 commit comments

Comments
 (0)