
Commit 4d9b845

Update docs for alternative dataset projects (#17096)
1 parent 1ddafcf commit 4d9b845

File tree

1 file changed: +102 −63 lines changed


docs/source-pytorch/data/alternatives.rst

Lines changed: 102 additions & 63 deletions
@@ -25,11 +25,35 @@ Below we showcase Lightning examples with packages that compete with the generic
faster depending on your use case. They might require custom data serialization, loading, and preprocessing that
is often hardware accelerated.

StreamingDataset
^^^^^^^^^^^^^^^^

As datasets grow in size and the number of nodes scales, loading training data can become a significant challenge.
The `StreamingDataset <https://github.com/mosaicml/streaming>`__ can make training on large datasets from cloud storage
as fast, cheap, and scalable as possible.

This library uses a custom-built :class:`~torch.utils.data.IterableDataset`. The library recommends iterating through it
via a regular :class:`~torch.utils.data.DataLoader`, which means that support in the ``Trainer`` is seamless:

.. code-block:: python

    import lightning as L
    from torch.utils.data import DataLoader
    from streaming import MDSWriter, StreamingDataset


    class YourDataset(StreamingDataset):
        ...


    # you could do this in the `prepare_data` hook too
    with MDSWriter(out="...", columns=...) as out:
        out.write(...)

    train_dataset = YourDataset()
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size)
    model = ...
    trainer = L.Trainer()
    trainer.fit(model, train_dataloader)
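
For concreteness, here is a minimal, self-contained sketch of the conversion and loading steps. The column names,
encodings, and directory path are illustrative assumptions rather than anything Lightning prescribes; check the
streaming docs for the encodings that fit your data:

.. code-block:: python

    import numpy as np
    from PIL import Image
    from torch.utils.data import DataLoader
    from streaming import MDSWriter, StreamingDataset

    # hypothetical schema: a JPEG image and an integer class label
    columns = {"image": "jpeg", "class": "int"}

    # serialize a few toy samples into MDS shards in a local directory
    with MDSWriter(out="./mds_data", columns=columns) as out:
        for i in range(100):
            img = Image.fromarray(np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8))
            out.write({"image": img, "class": i % 10})

    # stream the shards back; `local` can also be paired with a `remote` cloud URL
    dataset = StreamingDataset(local="./mds_data", shuffle=True)
    train_dataloader = DataLoader(dataset, batch_size=4)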

FFCV
^^^^
@@ -42,36 +66,47 @@ the desired GPU in your pipeline. When moving data to a specific device, you can

.. code-block:: python

    import lightning as L
    from ffcv.loader import Loader, OrderOption
    from ffcv.transforms import ToTensor, ToDevice, ToTorchImage, Cutout
    from ffcv.fields.decoders import IntDecoder, RandomResizedCropRGBImageDecoder

    # Random resized crop
    decoder = RandomResizedCropRGBImageDecoder((224, 224))
    # Data decoding and augmentation; Cutout requires a crop size
    image_pipeline = [decoder, Cutout(8), ToTensor(), ToTorchImage()]
    label_pipeline = [IntDecoder(), ToTensor()]
    # Pipeline for each data field
    pipelines = {"image": image_pipeline, "label": label_pipeline}
    # Replaces the PyTorch data loader (`torch.utils.data.DataLoader`)
    train_dataloader = Loader(
        write_path, batch_size=bs, num_workers=num_workers, order=OrderOption.RANDOM, pipelines=pipelines
    )

    model = ...
    trainer = L.Trainer()
    trainer.fit(model, train_dataloader)
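
Note that ``write_path`` above must point to a dataset already converted to FFCV's ``.beton`` format. A minimal
conversion sketch, where the field names and ``my_dataset`` (any map-style dataset of image/label pairs) are
illustrative assumptions:

.. code-block:: python

    from ffcv.writer import DatasetWriter
    from ffcv.fields import RGBImageField, IntField

    # `my_dataset` stands in for any indexed dataset returning (image, label) pairs
    writer = DatasetWriter(write_path, {"image": RGBImageField(max_resolution=256), "label": IntField()})
    writer.from_indexed_dataset(my_dataset)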

WebDataset
^^^^^^^^^^

The `WebDataset <https://webdataset.github.io/webdataset>`__ library makes it easy to write I/O pipelines for large
datasets. Datasets can be stored locally or in the cloud. A ``WebDataset`` is just an instance of a standard
:class:`~torch.utils.data.IterableDataset`, and the webdataset library contains a small wrapper (``WebLoader``) that
adds a fluid interface to the ``DataLoader`` (and is otherwise identical):

.. code-block:: python

    import lightning as L
    import webdataset as wds

    dataset = wds.WebDataset(urls)
    train_dataloader = wds.WebLoader(dataset)

    model = ...
    trainer = L.Trainer()
    trainer.fit(model, train_dataloader)
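
In practice you typically decode and batch samples through the fluid interface. A short sketch, where the shard URL
pattern and the ``jpg``/``cls`` key names are assumptions about how the shards were written:

.. code-block:: python

    import webdataset as wds

    # hypothetical shard naming; the brace pattern expands to shard-000000.tar ... shard-000099.tar
    urls = "https://example.com/shards/shard-{000000..000099}.tar"

    dataset = (
        wds.WebDataset(urls)
        .shuffle(1000)  # shuffle within a 1000-sample buffer
        .decode("pil")  # decode stored images to PIL images
        .to_tuple("jpg", "cls")  # select the image and label entries of each sample
    )
    train_dataloader = wds.WebLoader(dataset, batch_size=32, num_workers=4)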

You can find a complete example `here <https://github.com/webdataset/webdataset-lightning>`__.

NVIDIA DALI
^^^^^^^^^^^
@@ -80,44 +115,48 @@ By just changing ``device_id=0`` to ``device_id=self.trainer.local_rank`` we can

.. code-block:: python

    import lightning as L
    from nvidia.dali.pipeline import pipeline_def
    import nvidia.dali.types as types
    import nvidia.dali.fn as fn
    from nvidia.dali.plugin.pytorch import DALIGenericIterator
    import os

    # To run with different data, see the documentation of nvidia.dali.fn.readers.file
    # points to https://github.com/NVIDIA/DALI_extra
    data_root_dir = os.environ["DALI_EXTRA_PATH"]
    images_dir = os.path.join(data_root_dir, "db", "single", "jpeg")


    @pipeline_def(num_threads=4, device_id=0)
    def get_dali_pipeline():
        images, labels = fn.readers.file(file_root=images_dir, random_shuffle=True, name="Reader")
        # decode data on the GPU
        images = fn.decoders.image_random_crop(images, device="mixed", output_type=types.RGB)
        # the rest of the processing happens on the GPU as well
        images = fn.resize(images, resize_x=256, resize_y=256)
        images = fn.crop_mirror_normalize(
            images,
            crop_h=224,
            crop_w=224,
            mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
            std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
            mirror=fn.random.coin_flip(),
        )
        return images, labels


    train_dataloader = DALIGenericIterator(
        [get_dali_pipeline(batch_size=16)],
        ["data", "label"],
        reader_name="Reader",
    )

    model = ...
    trainer = L.Trainer()
    trainer.fit(model, train_dataloader)
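
To map each rank to its own GPU in distributed training (the ``device_id=self.trainer.local_rank`` change mentioned
in the hunk context above), the pipeline needs access to the trainer, for example by building it inside a
``LightningModule`` hook. A sketch reusing ``images_dir`` from the block above; the class name and hook-based
structure are illustrative:

.. code-block:: python

    import lightning as L
    import nvidia.dali.fn as fn
    import nvidia.dali.types as types
    from nvidia.dali.pipeline import pipeline_def
    from nvidia.dali.plugin.pytorch import DALIGenericIterator


    class LitClassifierWithDALI(L.LightningModule):
        def train_dataloader(self):
            # each process builds its own pipeline, pinned to its local GPU
            @pipeline_def(num_threads=4, device_id=self.trainer.local_rank)
            def get_dali_pipeline():
                images, labels = fn.readers.file(file_root=images_dir, random_shuffle=True, name="Reader")
                images = fn.decoders.image_random_crop(images, device="mixed", output_type=types.RGB)
                return images, labels

            return DALIGenericIterator([get_dali_pipeline(batch_size=16)], ["data", "label"], reader_name="Reader")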

You can find a complete tutorial `here <https://docs.nvidia.com/deeplearning/dali/user-guide/docs/examples/frameworks/pytorch/pytorch-lightning.html>`__.


Limitations
