
Commit 127a75c

add webdataset
1 parent c6b6c6d commit 127a75c

File tree

1 file changed: +211 -0 lines changed


libauc/datasets/webdataset.py

Lines changed: 211 additions & 0 deletions
@@ -0,0 +1,211 @@
# Adapted from https://github.com/mlfoundations/open_clip/blob/main/src/training/data.py

import logging
import random
from multiprocessing import Value
from typing import Dict, Callable, Optional

from torch.utils.data import get_worker_info

try:
    import webdataset as wds
    from webdataset.filters import _shuffle
    from webdataset.tariterators import base_plus_ext, url_opener, tar_file_expander, valid_sample
except ImportError:
    raise ImportError("webdataset is not installed. Please install it by running `pip install webdataset`.")


class SharedEpoch:
    """Epoch number shared with dataloader worker processes for distributed training"""
    def __init__(self, epoch: int = 0):
        self.shared_epoch = Value('i', epoch)

    def set_value(self, epoch):
        self.shared_epoch.value = epoch

    def get_value(self):
        return self.shared_epoch.value


def filter_no_caption_or_no_image(sample):
    """Check that a sample has both a caption and an image"""
    has_caption = ('txt' in sample)
    has_image = ('png' in sample or 'jpg' in sample or 'jpeg' in sample or 'webp' in sample)
    return has_caption and has_image


def log_and_continue(exn):
    """Call in an exception handler to ignore any exception, issue a warning, and continue."""
    logging.warning(f'Handling webdataset error ({repr(exn)}). Ignoring.')
    return True


def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
    """Return function over iterator that groups key, value pairs into samples.

    :param keys: function that splits the key into key and extension (base_plus_ext)
    :param lcase: convert suffixes to lower case (Default value = True)
    """
    current_sample = None
    for filesample in data:
        assert isinstance(filesample, dict)
        fname, value = filesample["fname"], filesample["data"]
        prefix, suffix = keys(fname)
        if prefix is None:
            continue
        if lcase:
            suffix = suffix.lower()
        # FIXME: the webdataset version throws if suffix is already in current_sample, but this can
        # happen in the current LAION-400M dataset if a tar ends with the same prefix as the next one
        # begins; rare, but possible since prefixes aren't unique across tar files in that dataset
        if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample:
            if valid_sample(current_sample):
                yield current_sample
            current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
        if suffixes is None or suffix in suffixes:
            current_sample[suffix] = value
    if valid_sample(current_sample):
        yield current_sample


def tarfile_to_samples_nothrow(src, handler=log_and_continue):
    """A re-implementation of the webdataset implementation with a group_by_keys that doesn't throw"""
    streams = url_opener(src, handler=handler)
    files = tar_file_expander(streams, handler=handler)
    samples = group_by_keys_nothrow(files, handler=handler)
    return samples


def pytorch_worker_seed(increment=0):
    """Get the dataloader worker seed from pytorch"""
    worker_info = get_worker_info()
    if worker_info is not None:
        # favour the seed already created for pytorch dataloader workers if it exists
        seed = worker_info.seed
        if increment:
            # space out seed increments so they can't overlap across workers in different iterations
            seed += increment * max(1, worker_info.num_workers)
        return seed
    # fall back to the wds rank-based seed
    return wds.utils.pytorch_worker_seed()


_SHARD_SHUFFLE_SIZE = 2000
_SHARD_SHUFFLE_INITIAL = 500
_SAMPLE_SHUFFLE_SIZE = 5000
_SAMPLE_SHUFFLE_INITIAL = 1000


class detshuffle2(wds.PipelineStage):
    """Deterministic shuffle according to seed and epoch"""
    def __init__(
        self,
        bufsize=1000,
        initial=100,
        seed=0,
        epoch=-1,
    ):
        self.bufsize = bufsize
        self.initial = initial
        self.seed = seed
        self.epoch = epoch

    def run(self, src):
        if isinstance(self.epoch, SharedEpoch):
            epoch = self.epoch.get_value()
        else:
            # NOTE: this epoch tracking is problematic in a multiprocess (dataloader workers or train)
            # situation as different workers may wrap at different times (or not at all).
            self.epoch += 1
            epoch = self.epoch
        rng = random.Random()
        if self.seed < 0:
            # If seed is negative, use the worker's seed; this will be different across all nodes/workers
            seed = pytorch_worker_seed(epoch)
        else:
            # This seed is deterministic AND the same across all nodes/workers in each epoch
            seed = self.seed + epoch
        rng.seed(seed)
        return _shuffle(src, self.bufsize, self.initial, rng)


class WebDataset(wds.DataPipeline):
    r"""
    An image-text dataset stored in webdataset format. For more information on the webdataset format,
    refer to https://github.com/webdataset/webdataset.

    Args:
        input_shards (str): Path to the dataset shards.
        is_train (bool): Whether the dataset is for training or evaluation.
        batch_size (int): Batch size per worker.
        preprocess_img (Callable): Function to preprocess the image.
        seed (int): Seed for shuffling the dataset.
        epoch (int): Start epoch number.
        tokenize (Optional[Callable]): Tokenizer function for the text data.
        return_index (bool): Whether to return the index of the data.
    """
    def __init__(self,
                 input_shards: str,
                 is_train: bool,
                 batch_size: int,
                 preprocess_img: Callable,
                 seed: int = 0,
                 epoch: int = 0,
                 tokenize: Optional[Callable] = None,
                 return_index: bool = False,
                 ):
        self.shared_epoch = SharedEpoch(epoch=epoch)  # shared epoch store to sync the epoch to dataloader worker processes
        pipeline = [wds.SimpleShardList(input_shards)]

        # at this point we have an iterator over all the shards
        if is_train:
            pipeline.extend([
                detshuffle2(
                    bufsize=_SHARD_SHUFFLE_SIZE,
                    initial=_SHARD_SHUFFLE_INITIAL,
                    seed=seed,
                    epoch=self.shared_epoch,
                ),
                wds.split_by_node,
                wds.split_by_worker,
            ])
            pipeline.extend([
                # at this point, we have an iterator over the shards assigned to each worker at each node
                tarfile_to_samples_nothrow,  # wds.tarfile_to_samples(handler=log_and_continue),
                wds.shuffle(
                    bufsize=_SAMPLE_SHUFFLE_SIZE,
                    initial=_SAMPLE_SHUFFLE_INITIAL,
                ),
            ])
        else:
            pipeline.extend([
                wds.split_by_worker,
                # at this point, we have an iterator over the shards assigned to each worker
                wds.tarfile_to_samples(handler=log_and_continue),
            ])

        # helper to parse the sample key (index) from the json metadata
        def json_parse_key(json_dict: Dict) -> int:
            return int(json_dict["key"])

        if return_index:
            rename = wds.rename(image="jpg;png;jpeg;webp", text="txt", key="json")
            if tokenize is not None:
                map_dict = wds.map_dict(image=preprocess_img, text=tokenize, key=json_parse_key)
            else:
                map_dict = wds.map_dict(image=preprocess_img, key=json_parse_key)
            to_tuple = wds.to_tuple("image", "text", "key", "key")
        else:
            rename = wds.rename(image="jpg;png;jpeg;webp", text="txt")
            if tokenize is not None:
                map_dict = wds.map_dict(image=preprocess_img, text=tokenize)
            else:
                map_dict = wds.map_dict(image=preprocess_img)
            to_tuple = wds.to_tuple("image", "text")
        pipeline.extend([
            wds.select(filter_no_caption_or_no_image),
            wds.decode("pilrgb", handler=log_and_continue),
            rename, map_dict, to_tuple,
            wds.batched(batch_size, partial=not is_train)
        ])

        super().__init__(*pipeline)
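
For reference, a minimal usage sketch of the new class follows (not part of this commit). The shard pattern, image transform, worker count, and import path are illustrative assumptions; wds.WebLoader is webdataset's DataLoader wrapper, and the import path simply mirrors the added file libauc/datasets/webdataset.py.

# Usage sketch (assumption, not part of this commit): build the pipeline and wrap it
# in webdataset's WebLoader. Shard pattern, transform, and worker count are placeholders.
import torchvision.transforms as T
import webdataset as wds

from libauc.datasets.webdataset import WebDataset  # path mirrors the file added in this commit

preprocess = T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor()])

train_set = WebDataset(
    input_shards="/data/shards/train-{000000..000099}.tar",  # placeholder shard pattern
    is_train=True,
    batch_size=64,
    preprocess_img=preprocess,
    seed=0,
    tokenize=None,        # pass a text tokenizer callable here if captions should be tokenized
    return_index=False,
)

# batching already happens inside the pipeline (wds.batched), so batch_size=None here
loader = wds.WebLoader(train_set, batch_size=None, num_workers=4)

for images, texts in loader:
    # images: stacked tensor batch; texts: list of raw caption strings (since tokenize=None)
    break

# when resuming at a later epoch, sync it so detshuffle2 reshuffles shards deterministically
train_set.shared_epoch.set_value(5)

Note that setting return_index=True instead yields (image, text, key, key) tuples, where the key is parsed from each sample's json metadata.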
