
Commit d2f23cf

merge imagenet fixes

2 parents: b6f37ac + 056281e

File tree

7 files changed: +334, -46 lines

algoperf/workloads/cifar/cifar_pytorch/workload.py

Lines changed: 2 additions & 2 deletions

@@ -110,12 +110,12 @@ def _build_dataset(
       batch_size=ds_iter_batch_size,
       shuffle=not USE_PYTORCH_DDP and is_train,
       sampler=sampler,
-      num_workers=4 if is_train else self.eval_num_workers,
+      num_workers=2 * N_GPUS if is_train else self.eval_num_workers,
       pin_memory=True,
       drop_last=is_train,
     )
-    dataloader = data_utils.PrefetchedWrapper(dataloader, DEVICE)
     dataloader = data_utils.cycle(dataloader, custom_sampler=USE_PYTORCH_DDP)
+    dataloader = data_utils.dataloader_iterator_wrapper(dataloader, DEVICE)
     return dataloader
 
   def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
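Note: data_utils.dataloader_iterator_wrapper, which replaces PrefetchedWrapper here, lives in algoperf/data_utils.py and is not shown in this diff. A minimal sketch of what such a host-to-device iterator wrapper could look like, purely for illustration (the body below is an assumption; only the call signature used above is taken from the commit):

def dataloader_iterator_wrapper(dataloader, device):
  # Hypothetical sketch: yield each batch with its tensors moved to `device`.
  # The actual algoperf.data_utils implementation may differ.
  for batch in dataloader:
    if isinstance(batch, dict):
      yield {k: v.to(device, non_blocking=True) for k, v in batch.items()}
    else:
      yield tuple(t.to(device, non_blocking=True) for t in batch)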

algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py

Lines changed: 109 additions & 6 deletions

@@ -3,18 +3,25 @@
 import contextlib
 import functools
 import itertools
+import json
 import math
 import os
 import random
-from typing import Dict, Iterator, Optional, Tuple
+import time
+from pathlib import Path
+from typing import Any, Callable, Dict, Iterator, Optional, Tuple, Union
 
 import numpy as np
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torchvision import transforms
-from torchvision.datasets.folder import ImageFolder
+from torchvision.datasets.folder import (
+  IMG_EXTENSIONS,
+  ImageFolder,
+  default_loader,
+)
 
 import algoperf.random_utils as prng
 from algoperf import data_utils, param_utils, pytorch_utils, spec
@@ -28,6 +35,100 @@
 USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup()
 
 
+class CachedImageFolder(ImageFolder):
+  """ImageFolder that caches the file listing to avoid repeated filesystem scans."""
+
+  def __init__(
+    self,
+    root: Union[str, Path],
+    cache_file: Optional[Union[str, Path]] = None,
+    transform: Optional[Callable] = None,
+    target_transform: Optional[Callable] = None,
+    loader: Callable[[str], Any] = default_loader,
+    is_valid_file: Optional[Callable[[str], bool]] = None,
+    allow_empty: bool = False,
+    rebuild_cache: bool = False,
+    cache_build_timeout_minutes: int = 30,
+  ):
+    self.root = os.path.expanduser(root)
+    self.transform = transform
+    self.target_transform = target_transform
+    self.loader = loader
+    self.extensions = IMG_EXTENSIONS if is_valid_file is None else None
+
+    # Default cache location: .cache_index.json in the root directory
+    if cache_file is None:
+      cache_file = os.path.join(self.root, '.cache_index.json')
+    self.cache_file = cache_file
+
+    is_distributed = dist.is_available() and dist.is_initialized()
+    rank = dist.get_rank() if is_distributed else 0
+
+    cache_exists = os.path.exists(self.cache_file)
+    needs_rebuild = rebuild_cache or not cache_exists
+
+    if needs_rebuild:
+      # We only want one process to build the cache
+      # and others to wait for it to finish.
+      if rank == 0:
+        self._build_and_save_cache(is_valid_file, allow_empty)
+      if is_distributed:
+        self._wait_for_cache(timeout_minutes=cache_build_timeout_minutes)
+        dist.barrier()
+
+    self._load_from_cache()
+
+    self.targets = [s[1] for s in self.samples]
+    self.imgs = self.samples
+
+  def _wait_for_cache(self, timeout_minutes: int):
+    """Poll for cache file to exist."""
+    timeout_seconds = timeout_minutes * 60
+    poll_interval = 5
+    elapsed = 0
+
+    while not os.path.exists(self.cache_file):
+      if elapsed >= timeout_seconds:
+        raise TimeoutError(
+          f'Timed out waiting for cache file after {timeout_minutes} minutes: {self.cache_file}'
+        )
+      time.sleep(poll_interval)
+      elapsed += poll_interval
+
+  def _load_from_cache(self):
+    """Load classes and samples from cache file."""
+    with open(os.path.abspath(self.cache_file), 'r') as f:
+      cache = json.load(f)
+    self.classes = cache['classes']
+    self.class_to_idx = cache['class_to_idx']
+    # Convert relative paths back to absolute
+    self.samples = [
+      (os.path.join(self.root, rel_path), idx)
+      for rel_path, idx in cache['samples']
+    ]
+
+  def _build_and_save_cache(self, is_valid_file, allow_empty):
+    """Scan filesystem, build index, and save to cache."""
+    self.classes, self.class_to_idx = self.find_classes(self.root)
+    self.samples = self.make_dataset(
+      self.root,
+      class_to_idx=self.class_to_idx,
+      extensions=self.extensions,
+      is_valid_file=is_valid_file,
+      allow_empty=allow_empty,
+    )
+
+    cache = {
+      'classes': self.classes,
+      'class_to_idx': self.class_to_idx,
+      'samples': [
+        (os.path.relpath(path, self.root), idx) for path, idx in self.samples
+      ],
+    }
+    with open(os.path.abspath(self.cache_file), 'w') as f:
+      json.dump(cache, f)
+
+
 def imagenet_v2_to_torch(
   batch: Dict[str, spec.Tensor],
 ) -> Dict[str, spec.Tensor]:
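For context, a short usage sketch of the new CachedImageFolder (the data path and transform below are made-up; only the constructor arguments mirror the class added above). On first use, rank 0 scans the directory tree and writes the JSON index; later constructions just load it:

from torchvision import transforms

train_dataset = CachedImageFolder(
  '/data/imagenet/train',                    # hypothetical path
  cache_file='.imagenet_cache_index.json',
  transform=transforms.ToTensor(),
)
print(len(train_dataset.samples), 'images in', len(train_dataset.classes), 'classes')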
@@ -119,8 +220,10 @@ def _build_dataset(
     )
 
     folder = 'train' if 'train' in split else 'val'
-    dataset = ImageFolder(
-      os.path.join(data_dir, folder), transform=transform_config
+    dataset = CachedImageFolder(
+      os.path.join(data_dir, folder),
+      transform=transform_config,
+      cache_file='.imagenet_cache_index.json',
     )
 
     if split == 'eval_train':
@@ -151,10 +254,11 @@ def _build_dataset(
       batch_size=ds_iter_batch_size,
       shuffle=not USE_PYTORCH_DDP and is_train,
       sampler=sampler,
-      num_workers=4 if is_train else self.eval_num_workers,
+      num_workers=5 * N_GPUS if is_train else self.eval_num_workers,
       pin_memory=True,
       drop_last=is_train,
       persistent_workers=is_train,
+      prefetch_factor=N_GPUS,
     )
     dataloader = data_utils.PrefetchedWrapper(dataloader, DEVICE)
     dataloader = data_utils.cycle(
@@ -163,7 +267,6 @@ def _build_dataset(
       use_mixup=use_mixup,
       mixup_alpha=0.2,
     )
-
     return dataloader
 
   def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
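A quick worked example of the new train-loader settings (assuming an 8-GPU machine; the numbers are illustrative, not measured):

N_GPUS = 8
num_workers = 5 * N_GPUS              # 40 worker processes per train DataLoader
prefetch_factor = N_GPUS              # 8 batches kept in flight per worker
print(num_workers * prefetch_factor)  # up to 320 batches buffered ahead of the GPU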

algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py

Lines changed: 7 additions & 7 deletions

@@ -5,7 +5,6 @@
 and https://github.com/lucidrains/vit-pytorch.
 """
 
-import math
 from typing import Any, Optional, Tuple, Union
 
 import torch
@@ -126,13 +125,14 @@ def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor:
     value_layer = self.transpose_for_scores(self.value(x))
     query_layer = self.transpose_for_scores(mixed_query_layer)
 
-    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
-    attention_scores = attention_scores / math.sqrt(self.head_dim)
-
-    attention_probs = F.softmax(attention_scores, dim=-1)
-    attention_probs = F.dropout(attention_probs, dropout_rate, self.training)
+    # Use built-in scaled_dot_product_attention (Flash Attention when available)
+    context_layer = F.scaled_dot_product_attention(
+      query_layer,
+      key_layer,
+      value_layer,
+      dropout_p=dropout_rate if self.training else 0.0,
+    )
 
-    context_layer = torch.matmul(attention_probs, value_layer)
     context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
     new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_dim,)
     context_layer = context_layer.view(new_context_layer_shape)
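The removed manual path and F.scaled_dot_product_attention compute the same result up to floating-point error, since SDPA's default scale is 1 / sqrt(head_dim). A small self-check sketch with made-up shapes (batch=2, heads=4, seq=16, head_dim=8), dropout disabled for the comparison:

import math
import torch
import torch.nn.functional as F

q, k, v = (torch.randn(2, 4, 16, 8) for _ in range(3))
# Manual attention as it existed before this commit.
scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(q.shape[-1])
manual = torch.matmul(F.softmax(scores, dim=-1), v)
# Fused attention as used after this commit.
fused = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
print(torch.allclose(manual, fused, atol=1e-5))  # True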

algorithms/baselines/external_tuning/jax_nadamw_full_budget.py

Lines changed: 0 additions & 6 deletions

@@ -340,12 +340,6 @@ def update_params(
       dropout_rate,
     )
   )
-
-  # Log loss, grad_norm.
-  if global_step % 100 == 0 and workload.metrics_logger is not None:
-    workload.metrics_logger.append_scalar_metrics(
-      {'loss': loss.item(), 'grad_norm': grad_norm.item()}, global_step
-    )
   return (new_optimizer_state, opt_update_fn), new_params, new_model_state
 
 

algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py

Lines changed: 0 additions & 23 deletions

@@ -5,7 +5,6 @@
 
 import torch
 import torch.distributed.nn as dist_nn
-from absl import logging
 from torch import Tensor
 from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
 
@@ -300,28 +299,6 @@ def update_params(
   optimizer_state['optimizer'].step()
   optimizer_state['scheduler'].step()
 
-  # Log training metrics - loss, grad_norm, batch_size.
-  if global_step <= 100 or global_step % 500 == 0:
-    with torch.no_grad():
-      parameters = [p for p in current_model.parameters() if p.grad is not None]
-      grad_norm = torch.norm(
-        torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2
-      )
-    if workload.metrics_logger is not None:
-      workload.metrics_logger.append_scalar_metrics(
-        {
-          'loss': loss.item(),
-          'grad_norm': grad_norm.item(),
-        },
-        global_step,
-      )
-    logging.info(
-      '%d) loss = %0.3f, grad_norm = %0.3f',
-      global_step,
-      loss.item(),
-      grad_norm.item(),
-    )
-
   return (optimizer_state, current_param_container, new_model_state)
 
 

submission_runner.py

Lines changed: 17 additions & 2 deletions

@@ -256,7 +256,6 @@ def train_once(
     'librispeech_conformer',
     'ogbg',
     'criteo1tb',
-    'imagenet_vit',
     'librispeech_deepspeech',
   ]
   eager_backend_workloads = []
@@ -267,6 +266,7 @@
     'ogbg',
     'wmt',
     'finewebedu_lm',
+    'imagenet_vit',
   ]
   base_workload = workloads.get_base_workload_name(workload_name)
   if base_workload in compile_error_workloads:
@@ -353,7 +353,7 @@
       log_dir, flags.FLAGS, hyperparameters
     )
     workload.attach_metrics_logger(metrics_logger)
-
+  step_10_end_time = None
   global_start_time = get_time()
   train_state['last_step_end_time'] = global_start_time
 
@@ -410,6 +410,21 @@
         train_state['training_complete'] = True
 
       train_step_end_time = get_time()
+      if global_step == 11:
+        step_10_end_time = train_step_end_time
+
+      # Log step time every 100 steps
+      if (global_step - 1) % 100 == 0 and workload.metrics_logger is not None:
+        if step_10_end_time is not None and global_step > 11:
+          elapsed_time_ms = (train_step_end_time - step_10_end_time) * 1000.0
+          elapsed_steps = global_step - 11
+          avg_step_time_ms = elapsed_time_ms / elapsed_steps
+        else:
+          avg_step_time_ms = 0.0
+        workload.metrics_logger.append_scalar_metrics(
+          {'step_time_ms': avg_step_time_ms},
+          global_step - 1,
+        )
 
       train_state['accumulated_submission_time'] += (
         train_step_end_time - train_state['last_step_end_time']
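A worked example of the step-time averaging added above (timestamps are made up): steps through global_step 11 are excluded from the average, presumably to skip warm-up and compilation, and the mean is taken from the end of step 11 onward.

step_10_end_time = 100.0      # seconds, recorded when global_step == 11
train_step_end_time = 158.0   # seconds, at global_step == 301
global_step = 301
avg_step_time_ms = (train_step_end_time - step_10_end_time) * 1000.0 / (global_step - 11)
print(avg_step_time_ms)  # 200.0 ms per step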
