|
3 | 3 | import contextlib |
4 | 4 | import functools |
5 | 5 | import itertools |
| 6 | +import json |
6 | 7 | import math |
7 | 8 | import os |
8 | 9 | import random |
9 | | -from typing import Dict, Iterator, Optional, Tuple |
| 10 | +import time |
| 11 | +from pathlib import Path |
| 12 | +from typing import Any, Callable, Dict, Iterator, Optional, Tuple, Union |
10 | 13 |
|
11 | 14 | import numpy as np |
12 | 15 | import torch |
13 | 16 | import torch.distributed as dist |
14 | 17 | import torch.nn.functional as F |
15 | 18 | from torch.nn.parallel import DistributedDataParallel as DDP |
16 | 19 | from torchvision import transforms |
17 | | -from torchvision.datasets.folder import ImageFolder |
| 20 | +from torchvision.datasets.folder import ( |
| 21 | + IMG_EXTENSIONS, |
| 22 | + ImageFolder, |
| 23 | + default_loader, |
| 24 | +) |
18 | 25 |
|
19 | 26 | import algoperf.random_utils as prng |
20 | 27 | from algoperf import data_utils, param_utils, pytorch_utils, spec |
|
28 | 35 | USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() |
29 | 36 |
|
30 | 37 |
|
class CachedImageFolder(ImageFolder):
  """ImageFolder that caches the file listing to avoid repeated filesystem scans.

  On first use (or when ``rebuild_cache=True``), rank 0 scans ``root`` exactly
  as ``ImageFolder`` would and serializes the resulting index (classes,
  class-to-index mapping, and sample paths relative to ``root``) to a JSON
  cache file. All ranks then load the index from that file, so only a single
  process ever pays the cost of walking a potentially huge directory tree.
  """

  def __init__(
    self,
    root: Union[str, Path],
    cache_file: Optional[Union[str, Path]] = None,
    transform: Optional[Callable] = None,
    target_transform: Optional[Callable] = None,
    loader: Callable[[str], Any] = default_loader,
    is_valid_file: Optional[Callable[[str], bool]] = None,
    allow_empty: bool = False,
    rebuild_cache: bool = False,
    cache_build_timeout_minutes: int = 30,
  ):
    """Build or load the cached file index.

    Args:
      root: Dataset root directory (one subdirectory per class).
      cache_file: Where to store the JSON index. Defaults to
        ``<root>/.cache_index.json``.
      transform: Per-sample transform, as in ``ImageFolder``.
      target_transform: Per-target transform, as in ``ImageFolder``.
      loader: Callable loading a sample from a path.
      is_valid_file: Optional file filter; when given, ``IMG_EXTENSIONS``
        is not used (mirrors ``DatasetFolder`` semantics).
      allow_empty: Forwarded to ``make_dataset``; allow classes with no files.
      rebuild_cache: Force a rescan even if a cache file already exists.
      cache_build_timeout_minutes: How long non-zero ranks wait for rank 0
        to publish the cache before raising ``TimeoutError``.
    """
    # NOTE: We intentionally do NOT call ImageFolder.__init__, because it
    # performs the full directory scan we are trying to avoid. Instead we
    # populate the same public attributes (classes, class_to_idx, samples,
    # targets, imgs) ourselves below.
    self.root = os.path.expanduser(root)
    self.transform = transform
    self.target_transform = target_transform
    self.loader = loader
    # DatasetFolder consults `extensions` only when no custom validity
    # predicate is supplied.
    self.extensions = IMG_EXTENSIONS if is_valid_file is None else None

    # Default cache location: .cache_index.json in the root directory.
    if cache_file is None:
      cache_file = os.path.join(self.root, '.cache_index.json')
    self.cache_file = cache_file

    is_distributed = dist.is_available() and dist.is_initialized()
    rank = dist.get_rank() if is_distributed else 0

    if rebuild_cache or not os.path.exists(self.cache_file):
      # Only one process builds the cache; the others wait for it to appear.
      if rank == 0:
        self._build_and_save_cache(is_valid_file, allow_empty)
      if is_distributed:
        self._wait_for_cache(timeout_minutes=cache_build_timeout_minutes)
        # The barrier — not the existence poll — is what guarantees every
        # rank reads the *fresh* cache: with rebuild_cache=True and an old
        # cache file on disk, _wait_for_cache returns immediately on
        # non-zero ranks, but no rank passes this barrier before rank 0
        # has finished writing.
        dist.barrier()

    self._load_from_cache()

    self.targets = [s[1] for s in self.samples]
    self.imgs = self.samples

  def _wait_for_cache(self, timeout_minutes: int):
    """Poll until the cache file exists, raising ``TimeoutError`` on expiry."""
    deadline = time.monotonic() + timeout_minutes * 60
    poll_interval_seconds = 5

    while not os.path.exists(self.cache_file):
      if time.monotonic() >= deadline:
        raise TimeoutError(
          f'Timed out waiting for cache file after {timeout_minutes} minutes: {self.cache_file}'
        )
      time.sleep(poll_interval_seconds)

  def _load_from_cache(self):
    """Load classes, class_to_idx, and samples from the JSON cache file."""
    with open(os.path.abspath(self.cache_file), 'r') as f:
      cache = json.load(f)
    self.classes = cache['classes']
    self.class_to_idx = cache['class_to_idx']
    # Convert relative paths back to absolute so samples are loadable
    # regardless of the current working directory.
    self.samples = [
      (os.path.join(self.root, rel_path), idx)
      for rel_path, idx in cache['samples']
    ]

  def _build_and_save_cache(self, is_valid_file, allow_empty):
    """Scan the filesystem, build the index, and atomically publish the cache.

    The cache is written to a temporary file and moved into place with
    ``os.replace`` (an atomic rename on POSIX and Windows). This matters
    because sibling processes in ``_wait_for_cache`` only poll for the
    file's *existence*: writing JSON directly to the final path would let
    a waiter observe — and fail to parse — a partially-written file.
    """
    self.classes, self.class_to_idx = self.find_classes(self.root)
    self.samples = self.make_dataset(
      self.root,
      class_to_idx=self.class_to_idx,
      extensions=self.extensions,
      is_valid_file=is_valid_file,
      allow_empty=allow_empty,
    )

    cache = {
      'classes': self.classes,
      'class_to_idx': self.class_to_idx,
      # Store paths relative to root so the cache survives dataset moves.
      'samples': [
        (os.path.relpath(path, self.root), idx) for path, idx in self.samples
      ],
    }

    cache_path = os.path.abspath(self.cache_file)
    # PID suffix avoids collisions if two independent (non-distributed)
    # jobs race to build the same cache.
    tmp_path = f'{cache_path}.tmp.{os.getpid()}'
    try:
      with open(tmp_path, 'w') as f:
        json.dump(cache, f)
      os.replace(tmp_path, cache_path)
    finally:
      # Remove the temp file if the rename never happened (e.g. json.dump
      # raised); suppress the error when it was successfully renamed.
      with contextlib.suppress(FileNotFoundError):
        os.remove(tmp_path)
| 131 | + |
31 | 132 | def imagenet_v2_to_torch( |
32 | 133 | batch: Dict[str, spec.Tensor], |
33 | 134 | ) -> Dict[str, spec.Tensor]: |
@@ -119,8 +220,10 @@ def _build_dataset( |
119 | 220 | ) |
120 | 221 |
|
121 | 222 | folder = 'train' if 'train' in split else 'val' |
122 | | - dataset = ImageFolder( |
123 | | - os.path.join(data_dir, folder), transform=transform_config |
| 223 | + dataset = CachedImageFolder( |
| 224 | + os.path.join(data_dir, folder), |
| 225 | + transform=transform_config, |
| 226 | + cache_file='.imagenet_cache_index.json', |
124 | 227 | ) |
125 | 228 |
|
126 | 229 | if split == 'eval_train': |
|
0 commit comments