@@ -25,11 +25,35 @@ Below we showcase Lightning examples with packages that compete with the generic
faster depending on your use case. They might require custom data serialization, loading, and preprocessing that
is often hardware accelerated.

- .. TODO(carmocca)
-     StreamingDataset
-     ^^^^^^^^^^^^^^^^
+ StreamingDataset
+ ^^^^^^^^^^^^^^^^

-     The `StreamingDataset <https://github.com/mosaicml/streaming>`__
+ As datasets grow in size and the number of nodes scales, loading training data can become a significant challenge.
+ The `StreamingDataset <https://github.com/mosaicml/streaming>`__ can make training on large datasets from cloud storage
+ as fast, cheap, and scalable as possible.
+
+ This library uses a custom-built :class:`~torch.utils.data.IterableDataset`. The library recommends iterating through it
+ via a regular :class:`~torch.utils.data.DataLoader`. This means that support in the ``Trainer`` is seamless:
+
+ .. code-block:: python
+
+     import lightning as L
+     from torch.utils.data import DataLoader
+     from streaming import MDSWriter, StreamingDataset
+
+
+     class YourDataset(StreamingDataset):
+         ...
+
+
+     # you could do this in the `prepare_data` hook too
+     with MDSWriter(out="...", columns=...) as out:
+         out.write(...)
+
+     train_dataset = YourDataset()
+     train_dataloader = DataLoader(train_dataset, batch_size=batch_size)
+     model = ...
+     trainer = L.Trainer()
+     trainer.fit(model, train_dataloader)
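+
+ The comment above notes that the ``MDSWriter`` call could live in the ``prepare_data`` hook. Below is a rough sketch of
+ one way that might look inside a ``LightningDataModule``; the ``columns`` spec, the ``./mds_data`` directory, and the
+ sample payloads are made-up placeholders rather than part of the original example:
+
+ .. code-block:: python
+
+     import os
+
+     import lightning as L
+     from streaming import MDSWriter, StreamingDataset
+     from torch.utils.data import DataLoader
+
+
+     class StreamingDataModule(L.LightningDataModule):
+         def prepare_data(self):
+             # hypothetical column spec and output directory, purely for illustration
+             if not os.path.isdir("./mds_data"):
+                 with MDSWriter(out="./mds_data", columns={"x": "pkl", "y": "int"}) as out:
+                     for i in range(100):
+                         out.write({"x": [float(i)] * 8, "y": i % 10})
+
+         def train_dataloader(self):
+             # read back the shards written in `prepare_data`
+             dataset = StreamingDataset(local="./mds_data", shuffle=True)
+             return DataLoader(dataset, batch_size=8)
+
+
+     model = ...
+     trainer = L.Trainer()
+     trainer.fit(model, datamodule=StreamingDataModule())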

FFCV
^^^^
@@ -42,36 +66,47 @@ the desired GPU in your pipeline. When moving data to a specific device, you can

.. code-block:: python

+     import lightning as L
    from ffcv.loader import Loader, OrderOption
    from ffcv.transforms import ToTensor, ToDevice, ToTorchImage, Cutout
    from ffcv.fields.decoders import IntDecoder, RandomResizedCropRGBImageDecoder

+     # Random resized crop
+     decoder = RandomResizedCropRGBImageDecoder((224, 224))
+     # Data decoding and augmentation
+     image_pipeline = [decoder, Cutout(), ToTensor(), ToTorchImage()]
+     label_pipeline = [IntDecoder(), ToTensor()]
+     # Pipeline for each data field
+     pipelines = {"image": image_pipeline, "label": label_pipeline}
+     # Replaces the PyTorch data loader (`torch.utils.data.DataLoader`)
+     train_dataloader = Loader(
+         write_path, batch_size=bs, num_workers=num_workers, order=OrderOption.RANDOM, pipelines=pipelines
+     )
+
+     model = ...
+     trainer = L.Trainer()
+     trainer.fit(model, train_dataloader)
+
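+ Note that ``ToDevice`` is imported above but not used in the minimal example. If you want the pipeline itself to place
+ batches on a particular GPU, one possible (hedged) extension of the image pipeline above is sketched below; the fixed
+ ``cuda:0`` device is an assumption, and with the ``Trainer`` the per-process choice would be based on ``trainer.local_rank``:
+
+ .. code-block:: python
+
+     import torch
+     from ffcv.transforms import ToDevice
+
+     # reuse `decoder` from above and ship the decoded tensors straight to the desired GPU
+     image_pipeline = [decoder, Cutout(), ToTensor(), ToDevice(torch.device("cuda:0")), ToTorchImage()]
+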
+ WebDataset
+ ^^^^^^^^^^
+
+ The `WebDataset <https://webdataset.github.io/webdataset>`__ makes it easy to write I/O pipelines for large datasets.
+ Datasets can be stored locally or in the cloud. ``WebDataset`` is just an instance of a standard IterableDataset.
+ The webdataset library contains a small wrapper (``WebLoader``) that adds a fluid interface to the DataLoader (and is otherwise identical).

-     class CustomClassifier(LitClassifier):
-         def train_dataloader(self):
-             # Random resized crop
-             decoder = RandomResizedCropRGBImageDecoder((224, 224))
-
-             # Data decoding and augmentation
-             image_pipeline = [decoder, Cutout(), ToTensor(), ToTorchImage()]
-             label_pipeline = [IntDecoder(), ToTensor()]
-
-             # Pipeline for each data field
-             pipelines = {"image": image_pipeline, "label": label_pipeline}
-
-             # Replaces PyTorch data loader (`torch.utils.data.Dataloader`)
-             loader = Loader(
-                 write_path, batch_size=bs, num_workers=num_workers, order=OrderOption.RANDOM, pipelines=pipelines
-             )
+ .. code-block:: python

-             return loader
+     import lightning as L
+     import webdataset as wds

+     dataset = wds.WebDataset(urls)
+     train_dataloader = wds.WebLoader(dataset)

- .. TODO(carmocca)
-     WebDataset
-     ^^^^^^^^^^
+     model = ...
+     trainer = L.Trainer()
+     trainer.fit(model, train_dataloader)

-     The `WebDataset <https://webdataset.github.io/webdataset>`__
+ You can find a complete example `here <https://github.com/webdataset/webdataset-lightning>`__.
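+
+ As a hedged sketch of the fluid interface mentioned above (the shard URL pattern and the ``"jpg"``/``"cls"`` keys are
+ assumptions about how the tar shards were written, not part of the original snippet), a more typical pipeline chains a
+ few of those calls:
+
+ .. code-block:: python
+
+     import webdataset as wds
+
+     # hypothetical shard pattern; webdataset expands the brace notation itself
+     urls = "shards/train-{000000..000009}.tar"
+     # shuffle samples, decode images to tensors, and pick the (image, label) pair out of each sample
+     dataset = wds.WebDataset(urls).shuffle(1000).decode("torchrgb").to_tuple("jpg", "cls")
+     train_dataloader = wds.WebLoader(dataset, batch_size=32, num_workers=4)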

NVIDIA DALI
^^^^^^^^^^^
@@ -80,44 +115,48 @@ By just changing ``device_id=0`` to ``device_id=self.trainer.local_rank`` we can

.. code-block:: python

-     from nvidia.dali.pipeline import pipeline_def
-     import nvidia.dali.types as types
-     import nvidia.dali.fn as fn
-     from nvidia.dali.plugin.pytorch import DALIGenericIterator
-     import os
-
-
-     class CustomLitClassifier(LitClassifier):
-         def train_dataloader(self):
-             # To run with different data, see documentation of nvidia.dali.fn.readers.file
-             # points to https://github.com/NVIDIA/DALI_extra
-             data_root_dir = os.environ["DALI_EXTRA_PATH"]
-             images_dir = os.path.join(data_root_dir, "db", "single", "jpeg")
-
-             @pipeline_def(num_threads=4, device_id=self.trainer.local_rank)
-             def get_dali_pipeline():
-                 images, labels = fn.readers.file(file_root=images_dir, random_shuffle=True, name="Reader")
-                 # decode data on the GPU
-                 images = fn.decoders.image_random_crop(images, device="mixed", output_type=types.RGB)
-                 # the rest of processing happens on the GPU as well
-                 images = fn.resize(images, resize_x=256, resize_y=256)
-                 images = fn.crop_mirror_normalize(
-                     images,
-                     crop_h=224,
-                     crop_w=224,
-                     mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
-                     std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
-                     mirror=fn.random.coin_flip(),
-                 )
-                 return images, labels
-
-             train_data = DALIGenericIterator(
-                 [get_dali_pipeline(batch_size=16)],
-                 ["data", "label"],
-                 reader_name="Reader",
-             )
-
-             return train_data
+     import lightning as L
+     from nvidia.dali.pipeline import pipeline_def
+     import nvidia.dali.types as types
+     import nvidia.dali.fn as fn
+     from nvidia.dali.plugin.pytorch import DALIGenericIterator
+     import os
+
+     # To run with different data, see documentation of nvidia.dali.fn.readers.file
+     # points to https://github.com/NVIDIA/DALI_extra
+     data_root_dir = os.environ["DALI_EXTRA_PATH"]
+     images_dir = os.path.join(data_root_dir, "db", "single", "jpeg")
+
+
+     @pipeline_def(num_threads=4, device_id=0)
+     def get_dali_pipeline():
+         images, labels = fn.readers.file(file_root=images_dir, random_shuffle=True, name="Reader")
+         # decode data on the GPU
+         images = fn.decoders.image_random_crop(images, device="mixed", output_type=types.RGB)
+         # the rest of processing happens on the GPU as well
+         images = fn.resize(images, resize_x=256, resize_y=256)
+         images = fn.crop_mirror_normalize(
+             images,
+             crop_h=224,
+             crop_w=224,
+             mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
+             std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
+             mirror=fn.random.coin_flip(),
+         )
+         return images, labels
+
+
+     train_dataloader = DALIGenericIterator(
+         [get_dali_pipeline(batch_size=16)],
+         ["data", "label"],
+         reader_name="Reader",
+     )
+
+     model = ...
+     trainer = L.Trainer()
+     trainer.fit(model, train_dataloader)
+
+ You can find a complete tutorial `here <https://docs.nvidia.com/deeplearning/dali/user-guide/docs/examples/frameworks/pytorch/pytorch-lightning.html>`__.
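+
+ Because ``self`` is only available inside a ``LightningModule``, the ``device_id=self.trainer.local_rank`` variant
+ mentioned above has to be built in a hook such as ``train_dataloader``. The sketch below shows one way to arrange that,
+ reusing the imports and ``images_dir`` from the snippet above; ``LitClassifier`` stands in for whatever module you
+ already have, and the ``...`` elides the body of the pipeline shown earlier:
+
+ .. code-block:: python
+
+     class CustomLitClassifier(LitClassifier):
+         def train_dataloader(self):
+             # each process decodes on its own GPU
+             @pipeline_def(num_threads=4, device_id=self.trainer.local_rank)
+             def get_dali_pipeline():
+                 images, labels = fn.readers.file(file_root=images_dir, random_shuffle=True, name="Reader")
+                 ...
+                 return images, labels
+
+             return DALIGenericIterator(
+                 [get_dali_pipeline(batch_size=16)],
+                 ["data", "label"],
+                 reader_name="Reader",
+             )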


Limitations