|
| 1 | +# Adapted from https://github.com/mlfoundations/open_clip/blob/main/src/training/data.py |
| 2 | + |
| 3 | +import logging |
| 4 | +import random |
| 5 | +from multiprocessing import Value |
| 6 | +from typing import Dict, Callable, Optional |
| 7 | + |
| 8 | +from torch.utils.data import get_worker_info |
| 9 | +try: |
| 10 | + import webdataset as wds |
| 11 | + from webdataset.filters import _shuffle |
| 12 | + from webdataset.tariterators import base_plus_ext, url_opener, tar_file_expander, valid_sample |
| 13 | +except ImportError: |
| 14 | + raise ImportError("webdataset is not installed. Please install it by running `pip install webdataset`.") |
| 15 | + |
| 16 | + |
class SharedEpoch:
    """Holder for an epoch counter backed by a ``multiprocessing.Value``.

    The counter is stored as a signed C int (type code ``'i'``) so that it can
    be read from other processes that hold a reference to the same object.
    """

    def __init__(self, epoch: int = 0):
        """Create the shared counter, initialised to ``epoch``."""
        self.shared_epoch = Value("i", epoch)

    def set_value(self, epoch: int) -> None:
        """Overwrite the stored epoch number with ``epoch``."""
        self.shared_epoch.value = epoch

    def get_value(self) -> int:
        """Return the epoch number currently stored."""
        return self.shared_epoch.value
| 27 | + |
| 28 | + |
def filter_no_caption_or_no_image(sample):
    """Return True only when ``sample`` holds both a caption and an image.

    A caption is the ``'txt'`` entry; an image is any one of the ``'png'``,
    ``'jpg'``, ``'jpeg'`` or ``'webp'`` entries. Keys are matched exactly
    (case-sensitive).
    """
    if 'txt' not in sample:
        return False
    return any(ext in sample for ext in ('png', 'jpg', 'jpeg', 'webp'))
| 34 | + |
| 35 | + |
def log_and_continue(exn):
    """Exception handler that logs ``exn`` as a warning and asks the caller to carry on.

    Always returns True, whatever the exception is.
    """
    message = f'Handling webdataset error ({repr(exn)}). Ignoring.'
    logging.warning(message)
    return True
| 40 | + |
| 41 | + |
def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
    """Generator that groups consecutive (fname, data) entries into sample dicts.

    Each yielded sample has ``__key__`` (the shared prefix), ``__url__`` (taken
    from the first entry of the group) and one item per accepted suffix.

    :param data: iterable of dicts carrying ``"fname"``, ``"data"`` and ``"__url__"``
    :param keys: function that splits the key into key and extension (base_plus_ext)
    :param lcase: convert suffixes to lower case (Default value = True)
    :param suffixes: if not None, only suffixes contained in it are stored in the sample
    :param handler: accepted for signature compatibility; not used by this function
    """
    sample = None
    for entry in data:
        assert isinstance(entry, dict)
        fname = entry["fname"]
        value = entry["data"]
        prefix, suffix = keys(fname)
        if prefix is None:
            # The name could not be split into prefix/suffix; skip the entry.
            continue
        if lcase:
            suffix = suffix.lower()
        # FIXME the webdataset version throws if the suffix is already in the current sample.
        # Here a repeated suffix starts a new sample instead: in the LAION400m dataset a tar
        # can end with the same prefix the next one begins with (rare, but prefixes are not
        # unique across tar files in that dataset).
        starts_new_sample = (
            sample is None
            or prefix != sample["__key__"]
            or suffix in sample
        )
        if starts_new_sample:
            # Emit the finished group (if it passes valid_sample) before opening the next one.
            if valid_sample(sample):
                yield sample
            sample = {"__key__": prefix, "__url__": entry["__url__"]}
        if suffixes is None or suffix in suffixes:
            sample[suffix] = value
    # Emit the trailing group once the input is exhausted.
    if valid_sample(sample):
        yield sample
| 68 | + |
| 69 | + |
def tarfile_to_samples_nothrow(src, handler=log_and_continue):
    """Turn a stream of shard URLs into samples using the non-throwing grouping step.

    Chains ``url_opener`` -> ``tar_file_expander`` -> ``group_by_keys_nothrow``,
    passing the same ``handler`` to every stage, and returns the resulting
    sample iterator.
    """
    return group_by_keys_nothrow(
        tar_file_expander(
            url_opener(src, handler=handler),
            handler=handler,
        ),
        handler=handler,
    )
| 76 | + |
| 77 | + |
def pytorch_worker_seed(increment=0):
    """Return a seed for the current dataloader worker.

    Inside a PyTorch dataloader worker the worker's own seed is used, offset by
    ``increment * max(1, num_workers)`` when ``increment`` is non-zero. Outside
    a worker (``get_worker_info()`` returns None) the webdataset seed helper is
    used instead and ``increment`` is ignored.
    """
    info = get_worker_info()
    if info is None:
        # Not running in a dataloader worker: fall back to the wds rank based seed.
        return wds.utils.pytorch_worker_seed()
    # Favour the seed PyTorch already created for this worker.
    base_seed = info.seed
    if not increment:
        return base_seed
    # Space out seed increments so they can't overlap across workers in different iterations.
    return base_seed + increment * max(1, info.num_workers)
| 90 | + |
| 91 | + |
# Shuffle-buffer settings used when WebDataset builds its training pipeline.
# Each *_SIZE value is passed as `bufsize` and each *_INITIAL value as `initial`.

# Shard-level shuffle (the detshuffle2 stage applied to the shard list).
_SHARD_SHUFFLE_SIZE = 2000
_SHARD_SHUFFLE_INITIAL = 500
# Sample-level shuffle (the wds.shuffle stage applied to samples read from the tar files).
_SAMPLE_SHUFFLE_SIZE = 5000
_SAMPLE_SHUFFLE_INITIAL = 1000
| 96 | + |
| 97 | + |
class detshuffle2(wds.PipelineStage):
    """Pipeline stage that shuffles its input with a seed derived from the epoch.

    Args:
        bufsize: Forwarded to ``_shuffle`` as the buffer size.
        initial: Forwarded to ``_shuffle`` as the initial buffer fill.
        seed: Base seed. A negative value selects the per-worker seed from
            ``pytorch_worker_seed`` instead of ``seed + epoch``.
        epoch: Either a ``SharedEpoch`` whose value is read on every run, or a
            plain int that is incremented on every run.
    """

    def __init__(self, bufsize=1000, initial=100, seed=0, epoch=-1):
        self.bufsize = bufsize
        self.initial = initial
        self.seed = seed
        self.epoch = epoch

    def run(self, src):
        """Shuffle ``src`` using an RNG seeded for the current epoch."""
        if isinstance(self.epoch, SharedEpoch):
            current_epoch = self.epoch.get_value()
        else:
            # NOTE: this local epoch tracking is problematic in a multiprocess (dataloader
            # workers or train) situation, as different workers may wrap at different times
            # (or not at all).
            self.epoch += 1
            current_epoch = self.epoch

        if self.seed < 0:
            # Negative seed: use the worker's seed, which will be different across all
            # nodes/workers.
            shuffle_seed = pytorch_worker_seed(current_epoch)
        else:
            # Deterministic AND the same across all nodes/workers in each epoch.
            shuffle_seed = self.seed + current_epoch

        return _shuffle(src, self.bufsize, self.initial, random.Random(shuffle_seed))
| 129 | + |
| 130 | + |
class WebDataset(wds.DataPipeline):
    r"""
    An image-text dataset that is stored in webdataset format. For more information on webdataset format,
    refer to https://github.com/webdataset/webdataset.

    The pipeline reads shards, expands them into samples, keeps only samples that have both a
    caption and an image, decodes them, and emits batched tuples. In training mode the shard list
    and the samples are additionally shuffled and the shards are split by node and by worker; in
    evaluation mode the shards are only split by worker and the last partial batch is kept.

    Args:
        input_shards (str): Path to the dataset shards.
        is_train (bool): Whether the dataset is for training or evaluation.
        batch_size (int): Batch size per worker.
        preprocess_img (Callable): Function to preprocess the image.
        seed (int): Seed for shuffling the dataset.
        epoch (int): Start epoch number.
        tokenize (Optional[Callable]): Tokenizer function for the text data. When None, no
            mapping function is applied to the text field.
        return_index (bool): Whether to return the index of the data.
    """

    def __init__(self,
                 input_shards: str,
                 is_train: bool,
                 batch_size: int,
                 preprocess_img: Callable,
                 seed: int = 0,
                 epoch: int = 0,
                 tokenize: Optional[Callable] = None,
                 return_index: bool = False,
                 ):
        # Shared epoch store used to sync the epoch to dataloader worker processes.
        self.shared_epoch = SharedEpoch(epoch=epoch)

        # Source stage: an iterator over all the shards.
        shard_list = wds.SimpleShardList(input_shards)

        if is_train:
            shard_stages = [
                detshuffle2(
                    bufsize=_SHARD_SHUFFLE_SIZE,
                    initial=_SHARD_SHUFFLE_INITIAL,
                    seed=seed,
                    epoch=self.shared_epoch,
                ),
                wds.split_by_node,
                wds.split_by_worker,
                # From here on: an iterator over the shards assigned to each worker at each node.
                tarfile_to_samples_nothrow,  # wds.tarfile_to_samples(handler=log_and_continue),
                wds.shuffle(
                    bufsize=_SAMPLE_SHUFFLE_SIZE,
                    initial=_SAMPLE_SHUFFLE_INITIAL,
                ),
            ]
        else:
            shard_stages = [
                wds.split_by_worker,
                # From here on: an iterator over the shards assigned to each worker.
                wds.tarfile_to_samples(handler=log_and_continue),
            ]

        # The key of each sample is read from the "key" entry of its json metadata.
        def json_parse_key(json_dict: Dict) -> int:
            return int(json_dict["key"])

        # Build the rename / map / output specifications incrementally; the keyword order
        # is image, text, key.
        rename_fields = {"image": "jpg;png;jpeg;webp", "text": "txt"}
        converters = {"image": preprocess_img}
        if tokenize is not None:
            converters["text"] = tokenize
        output_fields = ["image", "text"]
        if return_index:
            rename_fields["key"] = "json"
            converters["key"] = json_parse_key
            # The key is emitted twice in the output tuple.
            # NOTE(review): presumably one index for the image side and one for the text
            # side -- confirm against the consumer of these batches.
            output_fields.extend(["key", "key"])

        sample_stages = [
            wds.select(filter_no_caption_or_no_image),
            wds.decode("pilrgb", handler=log_and_continue),
            wds.rename(**rename_fields),
            wds.map_dict(**converters),
            wds.to_tuple(*output_fields),
            # Partial final batches are only produced outside of training.
            wds.batched(batch_size, partial=not is_train),
        ]

        super().__init__(shard_list, *shard_stages, *sample_stages)
0 commit comments