Commit f6974eb

Change num_workers for imagenet, add validation tests for step times
1 parent 9c93fc2

7 files changed: +212 −289 lines changed

algoperf/workloads/cifar/cifar_pytorch/workload.py

Lines changed: 2 additions & 2 deletions
@@ -110,12 +110,12 @@ def _build_dataset(
       batch_size=ds_iter_batch_size,
       shuffle=not USE_PYTORCH_DDP and is_train,
       sampler=sampler,
-      num_workers=4 if is_train else self.eval_num_workers,
+      num_workers=2 * N_GPUS if is_train else self.eval_num_workers,
       pin_memory=True,
       drop_last=is_train,
     )
-    dataloader = data_utils.PrefetchedWrapper(dataloader, DEVICE)
     dataloader = data_utils.cycle(dataloader, custom_sampler=USE_PYTORCH_DDP)
+    dataloader = data_utils.dataloader_iterator_wrapper(dataloader, DEVICE)
     return dataloader
 
   def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
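
Note on the CIFAR change: besides scaling `num_workers` with GPU count, the commit swaps `data_utils.PrefetchedWrapper` for `data_utils.dataloader_iterator_wrapper`, applied after `cycle` rather than before it. That wrapper is repo-internal and not shown in this commit; the following is only a hypothetical sketch of what a device-transfer wrapper of this kind typically does, not the repo's implementation:

```python
import torch

def dataloader_iterator_wrapper(dataloader, device):
  """Hypothetical sketch (not the repo's code): yield batches on `device`."""
  for batch in dataloader:
    # Move each tensor in the batch to the target device; non_blocking
    # lets the copy overlap compute when pin_memory=True was used.
    yield tuple(
        t.to(device, non_blocking=True) if torch.is_tensor(t) else t
        for t in batch)
```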

algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py

Lines changed: 2 additions & 2 deletions
@@ -254,10 +254,11 @@ def _build_dataset(
       batch_size=ds_iter_batch_size,
       shuffle=not USE_PYTORCH_DDP and is_train,
       sampler=sampler,
-      num_workers=4 if is_train else self.eval_num_workers,
+      num_workers=5 * N_GPUS if is_train else self.eval_num_workers,
       pin_memory=True,
       drop_last=is_train,
       persistent_workers=is_train,
+      prefetch_factor=N_GPUS,
     )
     dataloader = data_utils.PrefetchedWrapper(dataloader, DEVICE)
     dataloader = data_utils.cycle(
@@ -266,7 +267,6 @@ def _build_dataset(
       use_mixup=use_mixup,
       mixup_alpha=0.2,
     )
-
     return dataloader
 
   def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
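
For intuition on the new ImageNet settings: with `num_workers = w` and `prefetch_factor = p`, a PyTorch `DataLoader` keeps up to `w * p` batches queued ahead of the training loop, so both knobs deepen the host-side pipeline as GPU count grows. A self-contained sketch (the toy dataset and sizes are illustrative, not from the repo):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

N_GPUS = max(1, torch.cuda.device_count())  # assumption: mirrors the repo's N_GPUS

# Toy stand-in for the ImageNet dataset.
dataset = TensorDataset(
    torch.randn(1024, 3, 224, 224), torch.randint(0, 1000, (1024,)))

loader = DataLoader(
    dataset,
    batch_size=256,
    num_workers=5 * N_GPUS,   # heuristic from the diff: five workers per GPU
    prefetch_factor=N_GPUS,   # batches each worker loads ahead of time
    persistent_workers=True,  # keep workers alive between epochs
    pin_memory=True,          # page-locked buffers for fast host-to-device copies
    drop_last=True,
)
```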

algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py

Lines changed: 7 additions & 7 deletions
@@ -5,7 +5,6 @@
 and https://github.com/lucidrains/vit-pytorch.
 """
 
-import math
 from typing import Any, Optional, Tuple, Union
 
 import torch
@@ -126,13 +125,14 @@ def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor:
     value_layer = self.transpose_for_scores(self.value(x))
     query_layer = self.transpose_for_scores(mixed_query_layer)
 
-    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
-    attention_scores = attention_scores / math.sqrt(self.head_dim)
-
-    attention_probs = F.softmax(attention_scores, dim=-1)
-    attention_probs = F.dropout(attention_probs, dropout_rate, self.training)
+    # Use built-in scaled_dot_product_attention (Flash Attention when available)
+    context_layer = F.scaled_dot_product_attention(
+        query_layer,
+        key_layer,
+        value_layer,
+        dropout_p=dropout_rate if self.training else 0.0,
+    )
 
-    context_layer = torch.matmul(attention_probs, value_layer)
     context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
     new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_dim,)
     context_layer = context_layer.view(new_context_layer_shape)
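
The fused call is numerically equivalent to the removed manual path: `F.scaled_dot_product_attention` defaults to a `1/sqrt(head_dim)` scale, matching the deleted division. A quick standalone check, independent of the repo:

```python
import math
import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim) -- shapes chosen arbitrarily for the check.
q, k, v = (torch.randn(2, 8, 16, 64) for _ in range(3))

# Manual path, as removed by the diff (eval mode, so no dropout).
scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(q.size(-1))
manual = torch.matmul(F.softmax(scores, dim=-1), v)

# Fused path, as added by the diff; default scale is 1/sqrt(head_dim).
fused = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)

print(torch.allclose(manual, fused, atol=1e-5))  # expected: True
```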

algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py

Lines changed: 0 additions & 1 deletion
@@ -5,7 +5,6 @@
 
 import torch
 import torch.distributed.nn as dist_nn
-from absl import logging
 from torch import Tensor
 from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
 

benchmark_step_times.py

Lines changed: 0 additions & 274 deletions
This file was deleted.

submission_runner.py

Lines changed: 2 additions & 3 deletions
@@ -256,7 +256,6 @@ def train_once(
     'librispeech_conformer',
     'ogbg',
     'criteo1tb',
-    'imagenet_vit',
     'librispeech_deepspeech',
   ]
   eager_backend_workloads = []
@@ -266,6 +265,7 @@ def train_once(
     'librispeech_deepspeech',
     'ogbg',
     'wmt',
+    'imagenet_vit',
   ]
   base_workload = workloads.get_base_workload_name(workload_name)
   if base_workload in compile_error_workloads:
@@ -411,9 +411,8 @@ def train_once(
       train_step_end_time = get_time()
       if global_step == 11:
         step_10_end_time = train_step_end_time
-
+
       # Log step time every 100 steps
-      # Note: global_step was incremented, so use (global_step - 1) to match
       if (global_step - 1) % 100 == 0 and workload.metrics_logger is not None:
         if step_10_end_time is not None and global_step > 11:
           elapsed_time_ms = (train_step_end_time - step_10_end_time) * 1000.0
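
The timing block above excludes the first 10 steps as warm-up (compilation, caches) before averaging. A hypothetical standalone sketch of the same pattern; the sleep stands in for a real training step, and the names mirror the diff but nothing here is repo code:

```python
import time

step_10_end_time = None
global_step = 0

while global_step < 1000:
  time.sleep(0.001)  # stand-in for training step number `global_step`
  global_step += 1   # incremented before the timing block, as in the diff
  train_step_end_time = time.monotonic()

  if global_step == 11:  # step 10 just finished: end of the warm-up window
    step_10_end_time = train_step_end_time

  # Log every 100 steps; (global_step - 1) is the step that just ran.
  if (global_step - 1) % 100 == 0 and step_10_end_time is not None and global_step > 11:
    elapsed_time_ms = (train_step_end_time - step_10_end_time) * 1000.0
    mean_ms = elapsed_time_ms / (global_step - 11)  # steps timed since warm-up
    print(f'step {global_step - 1}: ~{mean_ms:.2f} ms/step')
```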
