Commit 2f865a1

ImageNet caching for faster dataset access (PyTorch)
1 parent c9899cf commit 2f865a1

File tree

3 files changed: +109 -6 lines changed

algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py
algoperf/workloads/ogbg/workload.py
scoring/performance_profile.py

algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py

Lines changed: 107 additions & 4 deletions
@@ -3,18 +3,25 @@
 import contextlib
 import functools
 import itertools
+import json
 import math
 import os
 import random
-from typing import Dict, Iterator, Optional, Tuple
+import time
+from pathlib import Path
+from typing import Any, Callable, Dict, Iterator, Optional, Tuple, Union
 
 import numpy as np
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torchvision import transforms
-from torchvision.datasets.folder import ImageFolder
+from torchvision.datasets.folder import (
+  IMG_EXTENSIONS,
+  ImageFolder,
+  default_loader,
+)
 
 import algoperf.random_utils as prng
 from algoperf import data_utils, param_utils, pytorch_utils, spec
@@ -28,6 +35,100 @@
 USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup()
 
 
+class CachedImageFolder(ImageFolder):
+  """ImageFolder that caches the file listing to avoid repeated filesystem scans."""
+
+  def __init__(
+    self,
+    root: Union[str, Path],
+    cache_file: Optional[Union[str, Path]] = None,
+    transform: Optional[Callable] = None,
+    target_transform: Optional[Callable] = None,
+    loader: Callable[[str], Any] = default_loader,
+    is_valid_file: Optional[Callable[[str], bool]] = None,
+    allow_empty: bool = False,
+    rebuild_cache: bool = False,
+    cache_build_timeout_minutes: int = 30,
+  ):
+    self.root = os.path.expanduser(root)
+    self.transform = transform
+    self.target_transform = target_transform
+    self.loader = loader
+    self.extensions = IMG_EXTENSIONS if is_valid_file is None else None
+
+    # Default cache location: .cache_index.json in the root directory
+    if cache_file is None:
+      cache_file = os.path.join(self.root, '.cache_index.json')
+    self.cache_file = cache_file
+
+    is_distributed = dist.is_available() and dist.is_initialized()
+    rank = dist.get_rank() if is_distributed else 0
+
+    cache_exists = os.path.exists(self.cache_file)
+    needs_rebuild = rebuild_cache or not cache_exists
+
+    if needs_rebuild:
+      # We only want one process to build the cache
+      # and others to wait for it to finish.
+      if rank == 0:
+        self._build_and_save_cache(is_valid_file, allow_empty)
+      if is_distributed:
+        self._wait_for_cache(timeout_minutes=cache_build_timeout_minutes)
+        dist.barrier()
+
+    self._load_from_cache()
+
+    self.targets = [s[1] for s in self.samples]
+    self.imgs = self.samples
+
+  def _wait_for_cache(self, timeout_minutes: int):
+    """Poll for cache file to exist."""
+    timeout_seconds = timeout_minutes * 60
+    poll_interval = 5
+    elapsed = 0
+
+    while not os.path.exists(self.cache_file):
+      if elapsed >= timeout_seconds:
+        raise TimeoutError(
+          f'Timed out waiting for cache file after {timeout_minutes} minutes: {self.cache_file}'
+        )
+      time.sleep(poll_interval)
+      elapsed += poll_interval
+
+  def _load_from_cache(self):
+    """Load classes and samples from cache file."""
+    with open(os.path.abspath(self.cache_file), 'r') as f:
+      cache = json.load(f)
+    self.classes = cache['classes']
+    self.class_to_idx = cache['class_to_idx']
+    # Convert relative paths back to absolute
+    self.samples = [
+      (os.path.join(self.root, rel_path), idx)
+      for rel_path, idx in cache['samples']
+    ]
+
+  def _build_and_save_cache(self, is_valid_file, allow_empty):
+    """Scan filesystem, build index, and save to cache."""
+    self.classes, self.class_to_idx = self.find_classes(self.root)
+    self.samples = self.make_dataset(
+      self.root,
+      class_to_idx=self.class_to_idx,
+      extensions=self.extensions,
+      is_valid_file=is_valid_file,
+      allow_empty=allow_empty,
+    )
+
+    cache = {
+      'classes': self.classes,
+      'class_to_idx': self.class_to_idx,
+      'samples': [
+        (os.path.relpath(path, self.root), idx) for path, idx in self.samples
+      ],
+    }
+    with open(os.path.abspath(self.cache_file), 'w') as f:
+      json.dump(cache, f)
+
+
 def imagenet_v2_to_torch(
   batch: Dict[str, spec.Tensor],
 ) -> Dict[str, spec.Tensor]:
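
For context, the index that _build_and_save_cache writes (and _load_from_cache reads back) is a single JSON object whose keys mirror the ImageFolder attributes. A minimal sketch of its shape, using illustrative class folders and file names rather than entries from a real run:

# Illustrative shape of the .cache_index.json file; the synset and file
# names below are placeholders, not taken from this commit.
example_cache = {
  'classes': ['n01440764', 'n01443537'],
  'class_to_idx': {'n01440764': 0, 'n01443537': 1},
  # Sample paths are stored relative to root; JSON turns the (path, idx)
  # tuples into two-element lists, which _load_from_cache unpacks again.
  'samples': [
    ['n01440764/n01440764_10026.JPEG', 0],
    ['n01443537/n01443537_1003.JPEG', 1],
  ],
}
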
@@ -119,8 +220,10 @@ def _build_dataset(
     )
 
     folder = 'train' if 'train' in split else 'val'
-    dataset = ImageFolder(
-      os.path.join(data_dir, folder), transform=transform_config
+    dataset = CachedImageFolder(
+      os.path.join(data_dir, folder),
+      transform=transform_config,
+      cache_file='.imagenet_cache_index.json',
     )
 
     if split == 'eval_train':
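
A rough standalone usage sketch, assuming the algoperf package containing this commit is importable and that /data/imagenet/train points at an ImageFolder-style directory (both are assumptions, not part of the diff):

# Hypothetical example path; the import matches the file modified above.
from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import (
  CachedImageFolder,
)

# The first construction scans the tree once and writes
# /data/imagenet/train/.cache_index.json (the default cache location);
# later constructions, and non-zero DDP ranks, only read that index.
train_set = CachedImageFolder('/data/imagenet/train')
print(len(train_set.samples), len(train_set.classes))

Passing rebuild_cache=True forces a fresh filesystem scan if the directory contents change.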

algoperf/workloads/ogbg/workload.py

Lines changed: 1 addition & 1 deletion
@@ -92,7 +92,7 @@ def max_allowed_runtime_sec(self) -> int:
 
   @property
   def eval_period_time_sec(self) -> int:
-    return 452 # approx 25 evals
+    return 452  # approx 25 evals
 
   def _build_input_queue(
     self,

scoring/performance_profile.py

Lines changed: 1 addition & 1 deletion
@@ -71,7 +71,7 @@
   'wer',
   'l1_loss',
   'loss',
-  'ppl'
+  'ppl',
 ]
 
 MAX_EVAL_METRICS = ['mean_average_precision', 'ssim', 'accuracy', 'bleu']
