Commit cbeb594

Move datasets/ to algoperf/datasets (otherwise it raises an import error, because we now import the `datasets` library for the LM workload)
1 parent 3608624 commit cbeb594
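
For context on the failure this rename avoids: with the repository root on `sys.path`, a top-level `datasets/` directory shadows the PyPI `datasets` package that the LM workload now imports, so `import datasets` resolves to the local code instead of the library. A minimal sketch of how to check which module wins (the paths in the comments are illustrative, not taken from this repo):

```python
import importlib.util

# Ask the import system what `import datasets` would actually load.
spec = importlib.util.find_spec('datasets')
print(spec.origin if spec else 'datasets not found')
# Before this commit: .../algorithmic-efficiency/datasets/... (local, shadowing)
# After this commit:  .../site-packages/datasets/__init__.py (the PyPI library)
```

Nesting the directory under the `algoperf` package removes it from the top-level module namespace, so `algoperf.datasets` and the external `datasets` library can coexist.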

File tree

13 files changed (+143 -93 lines)

Lines changed: 7 additions & 7 deletions

@@ -24,7 +24,7 @@ This document provides instructions on downloading and preparing all datasets ut
 *TL;DR to download and prepare a dataset, run `dataset_setup.py`:*
 
 ```bash
-python3 datasets/dataset_setup.py \
+python3 algoperf/datasets/dataset_setup.py \
 --data_dir=~/data \
 --<dataset_name>
 --<optional_flags>
@@ -88,7 +88,7 @@ By default, a user will be prompted before any files are deleted. If you do not
 From `algorithmic-efficiency` run:
 
 ```bash
-python3 datasets/dataset_setup.py \
+python3 algoperf/datasets/dataset_setup.py \
 --data_dir $DATA_DIR \
 --ogbg
 ```
@@ -124,7 +124,7 @@ In total, it should contain 13 files (via `find -type f | wc -l`) for a total of
 From `algorithmic-efficiency` run:
 
 ```bash
-python3 datasets/dataset_setup.py \
+python3 algoperf/datasets/dataset_setup.py \
 --data_dir $DATA_DIR \
 --wmt
 ```
@@ -194,7 +194,7 @@ you should get an email containing the URLS for "knee_singlecoil_train",
 "knee_singlecoil_val" and "knee_singlecoil_test".
 
 ```bash
-python3 datasets/dataset_setup.py \
+python3 algoperf/datasets/dataset_setup.py \
 --data_dir $DATA_DIR \
 --fastmri \
 --fastmri_knee_singlecoil_train_url '<knee_singlecoil_train_url>' \
@@ -235,7 +235,7 @@ The ImageNet data pipeline differs between the PyTorch and JAX workloads.
 Therefore, you will have to specify the framework (either `pytorch` or `jax`) through the framework flag.
 
 ```bash
-python3 datasets/dataset_setup.py \
+python3 algoperf/datasets/dataset_setup.py \
 --data_dir $DATA_DIR \
 --imagenet \
 --temp_dir $DATA_DIR/tmp \
@@ -349,7 +349,7 @@ In total, it should contain 20 files (via `find -type f | wc -l`) for a total of
 ### Criteo1TB
 
 ```bash
-python3 datasets/dataset_setup.py \
+python3 algoperf/datasets/dataset_setup.py \
 --data_dir $DATA_DIR \
 --temp_dir $DATA_DIR/tmp \
 --criteo1tb
@@ -378,7 +378,7 @@ In total, it should contain 885 files (via `find -type f | wc -l`) for a total o
 To download, train a tokenizer and preprocess the librispeech dataset:
 
 ```bash
-python3 datasets/dataset_setup.py \
+python3 algoperf/datasets/dataset_setup.py \
 --data_dir $DATA_DIR \
 --temp_dir $DATA_DIR/tmp \
 --librispeech

Lines changed: 3 additions & 3 deletions

@@ -56,7 +56,7 @@
 
 Example command:
 
-python3 datasets/dataset_setup.py \
+python3 algoperf/datasets/dataset_setup.py \
 --data_dir=~/data \
 --temp_dir=/tmp/mlcommons_data
 --imagenet \
@@ -73,8 +73,8 @@
 
 from algoperf.workloads.wmt import tokenizer
 from algoperf.workloads.wmt.input_pipeline import normalize_feature_names
-from datasets import librispeech_preprocess
-from datasets import librispeech_tokenizer
+from algoperf.datasets import librispeech_preprocess
+from algoperf.datasets import librispeech_tokenizer
 
 import functools
 import os

datasets/librispeech_preprocess.py renamed to algoperf/datasets/librispeech_preprocess.py

Lines changed: 1 addition & 1 deletion

@@ -14,7 +14,7 @@
 from absl import logging
 from pydub import AudioSegment
 
-from datasets import librispeech_tokenizer
+from algoperf.datasets import librispeech_tokenizer
 
 gfile = tf.io.gfile
 copy = tf.io.gfile.copy
File renamed without changes.

algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py

Lines changed: 3 additions & 2 deletions

@@ -94,7 +94,6 @@ def _build_dataset(
       batch_size = global_batch_size // N_GPUS
     else:
       batch_size = global_batch_size
-
 
     ds = input_pipeline.create_split(
       split,
@@ -107,7 +106,9 @@ def _build_dataset(
       mean_rgb=self.train_mean,
       stddev_rgb=self.train_stddev,
       cache=not train if cache is None else cache,
-      repeat_final_dataset=repeat_final_dataset if repeat_final_dataset is not None else train,
+      repeat_final_dataset=repeat_final_dataset
+      if repeat_final_dataset is not None
+      else train,
       aspect_ratio_range=self.aspect_ratio_range,
       area_range=self.scale_ratio_range,
       use_mixup=use_mixup,

algoperf/workloads/imagenet_resnet/input_pipeline.py

Lines changed: 6 additions & 2 deletions

@@ -396,9 +396,13 @@ def transpose_batch(batch):
       batch['inputs'] = tf.transpose(batch['inputs'], [0, 3, 1, 2])
       return batch
 
-    ds = ds.map(transpose_batch, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+    ds = ds.map(
+      transpose_batch, num_parallel_calls=tf.data.experimental.AUTOTUNE
+    )
   elif image_format != 'NHWC':
-    raise ValueError(f"image_format must be 'NHWC' or 'NCHW', got {image_format}")
+    raise ValueError(
+      f"image_format must be 'NHWC' or 'NCHW', got {image_format}"
+    )
 
   ds = ds.prefetch(10)

algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py

Lines changed: 1 addition & 3 deletions

@@ -5,7 +5,6 @@
 
 import torch
 import torch.distributed.nn as dist_nn
-from absl import logging
 from torch import Tensor
 from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
 
@@ -300,8 +299,7 @@ def update_params(
   optimizer_state['optimizer'].step()
   optimizer_state['scheduler'].step()
 
-  # Log training metrics - loss, grad_norm, batch_size.
-
+  # Log training metrics - loss, grad_norm.
   if global_step % 100 == 0 and workload.metrics_logger is not None:
     with torch.no_grad():
       parameters = [p for p in current_model.parameters() if p.grad is not None]

debug/benchmark_dataloader_jax.py

Lines changed: 11 additions & 11 deletions

@@ -18,14 +18,14 @@
 
 
 def main():
-  data_dir = '/home/ak4605/algoperf-data/imagenet/jax'
+  data_dir = '/home/ak4605/data/imagenet/jax'
   global_batch_size = 1024
   num_batches = 100
 
   rng = jax.random.PRNGKey(0)
   ds_builder = tfds.builder('imagenet2012:5.1.0', data_dir=data_dir)
 
-  print(f'Creating JAX ImageNet dataloader...')
+  print('Creating JAX ImageNet dataloader...')
   print(f'Batch size: {global_batch_size}')
   print(f'Num devices: {jax.local_device_count()}')
 
@@ -56,7 +56,7 @@ def main():
     start = time.perf_counter()
     batch = next(ds_iter)
     end = time.perf_counter()
-    print(f' Warmup batch {i+1}/5: {(end - start)*1000:.2f}ms')
+    print(f' Warmup batch {i + 1}/5: {(end - start) * 1000:.2f}ms')
 
   print(f"Batch 'inputs' shape: {batch['inputs'].shape}")
 
@@ -71,19 +71,19 @@ def main():
     end = time.perf_counter()
     times.append(end - start)
     if (i + 1) % 20 == 0:
-      print(f' Batch {i+1}/{num_batches}: {times[-1]*1000:.2f}ms')
+      print(f' Batch {i + 1}/{num_batches}: {times[-1] * 1000:.2f}ms')
 
   times = np.array(times)
-  print(f'\n=== JAX DataLoader Results ===')
-  print(f'Mean time per batch: {times.mean()*1000:.2f}ms')
-  print(f'Std time per batch: {times.std()*1000:.2f}ms')
-  print(f'Min time per batch: {times.min()*1000:.2f}ms')
-  print(f'Max time per batch: {times.max()*1000:.2f}ms')
+  print('\n=== JAX DataLoader Results ===')
+  print(f'Mean time per batch: {times.mean() * 1000:.2f}ms')
+  print(f'Std time per batch: {times.std() * 1000:.2f}ms')
+  print(f'Min time per batch: {times.min() * 1000:.2f}ms')
+  print(f'Max time per batch: {times.max() * 1000:.2f}ms')
   print(f'Throughput: {global_batch_size / times.mean():.2f} images/sec')
 
   # Print machine-readable results for the fish script
-  print(f'\n=== RESULTS ===')
-  print(f'MEAN_MS={times.mean()*1000:.2f}')
+  print('\n=== RESULTS ===')
+  print(f'MEAN_MS={times.mean() * 1000:.2f}')
   print(f'THROUGHPUT={global_batch_size / times.mean():.2f}')
 
debug/benchmark_dataloader_pytorch.py

Lines changed: 20 additions & 17 deletions

@@ -5,13 +5,14 @@
 import jax
 import numpy as np
 import tensorflow as tf
+
 tf.config.set_visible_devices([], 'GPU')  # Disable TF GPU usage
-import tensorflow_datasets as tfds
-import torch
-import torch.distributed as dist
+import tensorflow_datasets as tfds  # noqa: E402
+import torch  # noqa: E402
+import torch.distributed as dist  # noqa: E402
 
-from algoperf import pytorch_utils
-from algoperf.workloads.imagenet_resnet import input_pipeline
+from algoperf import pytorch_utils  # noqa: E402
+from algoperf.workloads.imagenet_resnet import input_pipeline  # noqa: E402
 
 # ImageNet constants (same as workload)
 TRAIN_MEAN = (0.485 * 255, 0.456 * 255, 0.406 * 255)
@@ -30,12 +31,12 @@ def main():
   torch.cuda.set_device(RANK)
   dist.init_process_group('nccl')
 
-  data_dir = '/home/ak4605/algoperf-data/imagenet/jax'
+  data_dir = '/home/ak4605/data/imagenet/jax'
   global_batch_size = 1024
   num_batches = 100
 
   if RANK == 0:
-    print(f'Creating PyTorch ImageNet dataloader (shared TFDS pipeline)...')
+    print('Creating PyTorch ImageNet dataloader (shared TFDS pipeline)...')
     print(f'Batch size: {global_batch_size}')
     print(f'Num GPUs: {N_GPUS}')
     print(f'USE_PYTORCH_DDP: {USE_PYTORCH_DDP}')
@@ -77,7 +78,9 @@ def main():
   def get_batch():
     batch = next(ds_iter)
     inputs = torch.from_numpy(batch['inputs'].numpy()).to(DEVICE)
-    targets = torch.from_numpy(batch['targets'].numpy()).to(DEVICE, dtype=torch.long)
+    targets = torch.from_numpy(batch['targets'].numpy()).to(
+      DEVICE, dtype=torch.long
+    )
     return {'inputs': inputs, 'targets': targets}
 
   # Warmup
@@ -88,7 +91,7 @@ def get_batch():
     batch = get_batch()
     end = time.perf_counter()
     if RANK == 0:
-      print(f' Warmup batch {i+1}/5: {(end - start)*1000:.2f}ms')
+      print(f' Warmup batch {i + 1}/5: {(end - start) * 1000:.2f}ms')
 
   if RANK == 0:
     print(f"Batch 'inputs' shape: {batch['inputs'].shape}")
@@ -109,20 +112,20 @@ def get_batch():
     end = time.perf_counter()
     times.append(end - start)
     if RANK == 0 and (i + 1) % 20 == 0:
-      print(f' Batch {i+1}/{num_batches}: {times[-1]*1000:.2f}ms')
+      print(f' Batch {i + 1}/{num_batches}: {times[-1] * 1000:.2f}ms')
 
   times = np.array(times)
   if RANK == 0:
-    print(f'\n=== PyTorch DataLoader Results ===')
-    print(f'Mean time per batch: {times.mean()*1000:.2f}ms')
-    print(f'Std time per batch: {times.std()*1000:.2f}ms')
-    print(f'Min time per batch: {times.min()*1000:.2f}ms')
-    print(f'Max time per batch: {times.max()*1000:.2f}ms')
+    print('\n=== PyTorch DataLoader Results ===')
+    print(f'Mean time per batch: {times.mean() * 1000:.2f}ms')
+    print(f'Std time per batch: {times.std() * 1000:.2f}ms')
+    print(f'Min time per batch: {times.min() * 1000:.2f}ms')
+    print(f'Max time per batch: {times.max() * 1000:.2f}ms')
     print(f'Throughput: {global_batch_size / times.mean():.2f} images/sec')
 
     # Print machine-readable results for the fish script
-    print(f'\n=== RESULTS ===')
-    print(f'MEAN_MS={times.mean()*1000:.2f}')
+    print('\n=== RESULTS ===')
+    print(f'MEAN_MS={times.mean() * 1000:.2f}')
     print(f'THROUGHPUT={global_batch_size / times.mean():.2f}')
 
   if USE_PYTORCH_DDP:
