|
| 1 | +import os |
| 2 | +import subprocess |
| 3 | +import time |
| 4 | + |
| 5 | +import numpy as np |
| 6 | + |
| 7 | +from logging import getLogger |
| 8 | + |
| 9 | +import torch |
| 10 | +import torchvision |
| 11 | + |
| 12 | +_GLOBAL_SEED = 0 |
| 13 | +logger = getLogger() |
| 14 | + |
| 15 | + |
def make_imagenet1k(
    transform,
    batch_size,
    collator=None,
    pin_mem=True,
    num_workers=8,
    world_size=1,
    rank=0,
    root_path=None,
    image_folder=None,
    training=True,
    copy_data=False,
    drop_last=True,
    subset_file=None,
):
    """
    Build the ImageNet-1k dataset, distributed sampler, and data loader.

    :param transform: transform applied to each loaded image
    :param batch_size: samples per batch on this rank
    :param collator: optional custom collate_fn for the loader
    :param pin_mem: whether the loader pins host memory
    :param num_workers: number of loader worker processes
    :param world_size: number of distributed replicas
    :param rank: rank of this process among the replicas
    :param root_path: root network directory for ImageNet data
    :param image_folder: path to images inside root_path
    :param training: load the train split if True, else validation
    :param copy_data: whether to copy data locally before loading
    :param drop_last: drop the final incomplete batch if True
    :param subset_file: optional '.txt' file listing image IDs to keep
    :return: (dataset, data_loader, dist_sampler)
    """
    dataset = ImageNet(
        root=root_path,
        image_folder=image_folder,
        transform=transform,
        train=training,
        copy_data=copy_data,
        index_targets=False,
    )
    # Optionally restrict the dataset to the images named in subset_file.
    if subset_file is not None:
        dataset = ImageNetSubset(dataset, subset_file)
    logger.info("ImageNet dataset created")

    dist_sampler = torch.utils.data.distributed.DistributedSampler(
        dataset=dataset, num_replicas=world_size, rank=rank
    )

    loader_kwargs = dict(
        batch_size=batch_size,
        sampler=dist_sampler,
        collate_fn=collator,
        num_workers=num_workers,
        pin_memory=pin_mem,
        drop_last=drop_last,
        persistent_workers=False,
    )
    data_loader = torch.utils.data.DataLoader(dataset, **loader_kwargs)
    logger.info("ImageNet unsupervised data loader created")

    return dataset, data_loader, dist_sampler
| 58 | + |
| 59 | + |
class ImageNet(torchvision.datasets.ImageFolder):

    def __init__(
        self,
        root,
        image_folder="imagenet_full_size/061417/",
        tar_file="imagenet_full_size-061417.tar.gz",
        transform=None,
        train=True,
        job_id=None,
        local_rank=None,
        copy_data=True,
        index_targets=False,
    ):
        """
        ImageNet

        Dataset wrapper (can copy data locally to machine)

        :param root: root network directory for ImageNet data
        :param image_folder: path to images inside root network directory
        :param tar_file: zipped image_folder inside root network directory
        :param transform: transform applied to each loaded image
        :param train: whether to load train data (or validation)
        :param job_id: scheduler job-id used to create dir on local machine
        :param local_rank: rank of this process on the local machine
        :param copy_data: whether to copy data from network file locally
        :param index_targets: whether to index the id of each labeled image
        """

        suffix = "train/" if train else "val/"
        data_path = None
        if copy_data:
            logger.info("copying data locally")
            data_path = copy_imgnt_locally(
                root=root,
                suffix=suffix,
                image_folder=image_folder,
                tar_file=tar_file,
                job_id=job_id,
                local_rank=local_rank,
            )
        # Fall back to the network path when copying is disabled or failed.
        if (not copy_data) or (data_path is None):
            data_path = os.path.join(root, image_folder, suffix)
        logger.info(f"data-path {data_path}")

        super().__init__(root=data_path, transform=transform)
        logger.info("Initialized ImageNet")

        if index_targets:
            self.targets = np.array([target for _, target in self.samples])
            # NOTE: np.array over (path, target) tuples unifies the dtype,
            # turning both fields into strings.
            self.samples = np.array(self.samples)

            mint = None
            self.target_indices = []
            for t in range(len(self.classes)):
                # np.flatnonzero always returns a 1-D array, so tolist()
                # yields a list even when a class has a single sample
                # (np.squeeze(np.argwhere(...)) would collapse that case
                # to a bare scalar and break len()/iteration downstream).
                indices = np.flatnonzero(self.targets == t).tolist()
                self.target_indices.append(indices)
                mint = len(indices) if mint is None else min(mint, len(indices))
                logger.debug(f"num-labeled target {t} {len(indices)}")
            logger.info(f"min. labeled indices {mint}")
| 123 | + |
class ImageNetSubset(object):

    def __init__(self, dataset, subset_file):
        """
        ImageNetSubset

        :param dataset: ImageNet dataset object
        :param subset_file: '.txt' file containing IDs of IN1K images to keep
        """
        self.dataset = dataset
        self.subset_file = subset_file
        self.filter_dataset_(subset_file)

    def filter_dataset_(self, subset_file):
        """Filter self.dataset to a subset.

        Each non-empty line of subset_file is expected to name one image,
        with the class folder name being the text before the first '_'.
        Rebuilds self.samples as (image-path, target) pairs.
        """
        root = self.dataset.root
        class_to_idx = self.dataset.class_to_idx
        # -- update samples to subset of IN1k targets/samples
        new_samples = []
        logger.info(f"Using {subset_file}")
        with open(subset_file, "r") as rfile:
            for line in rfile:
                img = line.split("\n")[0]
                if not img:
                    # Skip blank lines (e.g. trailing newline at EOF);
                    # otherwise class_to_idx[''] raises KeyError.
                    continue
                class_name = img.split("_")[0]
                target = class_to_idx[class_name]
                new_samples.append((os.path.join(root, class_name, img), target))
        self.samples = new_samples

    @property
    def classes(self):
        """Class names of the wrapped dataset."""
        return self.dataset.classes

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        """Load, transform, and return the (image, target) pair at index."""
        path, target = self.samples[index]
        img = self.dataset.loader(path)
        if self.dataset.transform is not None:
            img = self.dataset.transform(img)
        if self.dataset.target_transform is not None:
            target = self.dataset.target_transform(target)
        return img, target
| 167 | + |
| 168 | + |
def copy_imgnt_locally(
    root,
    suffix,
    image_folder="imagenet_full_size/061417/",
    tar_file="imagenet_full_size-061417.tar.gz",
    job_id=None,
    local_rank=None,
):
    """
    Copy the ImageNet tar from network storage to local scratch and extract.

    Local rank 0 extracts the archive and then writes a signal file; every
    other local rank polls for that file before proceeding.

    :param root: root network directory for ImageNet data
    :param suffix: split sub-directory inside image_folder ('train/'/'val/')
    :param image_folder: path to images inside the extracted archive
    :param tar_file: zipped image_folder inside root
    :param job_id: scheduler job-id (falls back to $SLURM_JOBID)
    :param local_rank: rank on this machine (falls back to $SLURM_LOCALID)
    :return: local data path, or None when job-id/local-rank are unavailable
    :raises subprocess.CalledProcessError: if the extraction command fails
    """
    if job_id is None:
        try:
            job_id = os.environ["SLURM_JOBID"]
        except Exception:
            logger.info("No job-id, will load directly from network file")
            return None

    if local_rank is None:
        try:
            local_rank = int(os.environ["SLURM_LOCALID"])
        except Exception:
            # Fixed copy-pasted message that previously said "No job-id".
            logger.info("No local-rank, will load directly from network file")
            return None

    source_file = os.path.join(root, tar_file)
    target = f"/scratch/slurm_tmpdir/{job_id}/"
    target_file = os.path.join(target, tar_file)
    data_path = os.path.join(target, image_folder, suffix)
    logger.info(f"{source_file}\n{target}\n{target_file}\n{data_path}")

    tmp_sgnl_file = os.path.join(target, "copy_signal.txt")

    if not os.path.exists(data_path):
        if local_rank == 0:
            commands = [["tar", "-xf", source_file, "-C", target]]
            for cmnd in commands:
                start_time = time.time()
                logger.info(f"Executing {cmnd}")
                # check=True: a failed extraction must raise instead of
                # falling through to write the signal file, which would
                # unblock the other ranks with missing data.
                subprocess.run(cmnd, check=True)
                logger.info(f"Cmnd took {(time.time()-start_time)/60.} min.")
            with open(tmp_sgnl_file, "w") as f:
                print("Done copying locally.", file=f)
        else:
            # Non-zero local ranks wait for rank 0's completion signal.
            while not os.path.exists(tmp_sgnl_file):
                time.sleep(60)
                logger.info(f"{local_rank}: Checking {tmp_sgnl_file}")

    return data_path
0 commit comments