Skip to content

Commit 3608624

Browse files
committed
add debug scripts
1 parent c2f443b commit 3608624

File tree

6 files changed

+738
-0
lines changed

6 files changed

+738
-0
lines changed

debug/benchmark_dataloader_jax.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
"""Benchmark script for JAX ImageNet dataloader."""
2+
3+
import time
4+
5+
import jax
6+
import numpy as np
7+
import tensorflow_datasets as tfds
8+
9+
from algoperf.workloads.imagenet_resnet import input_pipeline
10+
11+
# ImageNet constants (same as workload)
12+
TRAIN_MEAN = (0.485 * 255, 0.456 * 255, 0.406 * 255)
13+
TRAIN_STDDEV = (0.229 * 255, 0.224 * 255, 0.225 * 255)
14+
CENTER_CROP_SIZE = 224
15+
RESIZE_SIZE = 256
16+
ASPECT_RATIO_RANGE = (0.75, 4.0 / 3.0)
17+
SCALE_RATIO_RANGE = (0.08, 1.0)
18+
19+
20+
def main():
21+
data_dir = '/home/ak4605/algoperf-data/imagenet/jax'
22+
global_batch_size = 1024
23+
num_batches = 100
24+
25+
rng = jax.random.PRNGKey(0)
26+
ds_builder = tfds.builder('imagenet2012:5.1.0', data_dir=data_dir)
27+
28+
print(f'Creating JAX ImageNet dataloader...')
29+
print(f'Batch size: {global_batch_size}')
30+
print(f'Num devices: {jax.local_device_count()}')
31+
32+
ds = input_pipeline.create_split(
33+
split='train',
34+
dataset_builder=ds_builder,
35+
rng=rng,
36+
global_batch_size=global_batch_size,
37+
train=True,
38+
image_size=CENTER_CROP_SIZE,
39+
resize_size=RESIZE_SIZE,
40+
mean_rgb=TRAIN_MEAN,
41+
stddev_rgb=TRAIN_STDDEV,
42+
cache=False,
43+
repeat_final_dataset=True,
44+
aspect_ratio_range=ASPECT_RATIO_RANGE,
45+
area_range=SCALE_RATIO_RANGE,
46+
use_mixup=False,
47+
use_randaug=False,
48+
image_format='NHWC',
49+
)
50+
51+
ds_iter = iter(ds)
52+
53+
# Warmup
54+
print('Warming up...')
55+
for i in range(5):
56+
start = time.perf_counter()
57+
batch = next(ds_iter)
58+
end = time.perf_counter()
59+
print(f' Warmup batch {i+1}/5: {(end - start)*1000:.2f}ms')
60+
61+
print(f"Batch 'inputs' shape: {batch['inputs'].shape}")
62+
63+
# Benchmark
64+
print(f'Benchmarking {num_batches} batches...')
65+
times = []
66+
for i in range(num_batches):
67+
start = time.perf_counter()
68+
batch = next(ds_iter)
69+
# Force sync by accessing data
70+
_ = np.asarray(batch['inputs'][0, 0, 0, 0])
71+
end = time.perf_counter()
72+
times.append(end - start)
73+
if (i + 1) % 20 == 0:
74+
print(f' Batch {i+1}/{num_batches}: {times[-1]*1000:.2f}ms')
75+
76+
times = np.array(times)
77+
print(f'\n=== JAX DataLoader Results ===')
78+
print(f'Mean time per batch: {times.mean()*1000:.2f}ms')
79+
print(f'Std time per batch: {times.std()*1000:.2f}ms')
80+
print(f'Min time per batch: {times.min()*1000:.2f}ms')
81+
print(f'Max time per batch: {times.max()*1000:.2f}ms')
82+
print(f'Throughput: {global_batch_size / times.mean():.2f} images/sec')
83+
84+
# Print machine-readable results for the fish script
85+
print(f'\n=== RESULTS ===')
86+
print(f'MEAN_MS={times.mean()*1000:.2f}')
87+
print(f'THROUGHPUT={global_batch_size / times.mean():.2f}')
88+
89+
90+
if __name__ == '__main__':
91+
main()
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
"""Benchmark script for PyTorch ImageNet dataloader using shared TFDS pipeline."""
2+
3+
import time
4+
5+
import jax
6+
import numpy as np
7+
import tensorflow as tf
8+
tf.config.set_visible_devices([], 'GPU') # Disable TF GPU usage
9+
import tensorflow_datasets as tfds
10+
import torch
11+
import torch.distributed as dist
12+
13+
from algoperf import pytorch_utils
14+
from algoperf.workloads.imagenet_resnet import input_pipeline
15+
16+
# ImageNet constants (same as workload)
17+
TRAIN_MEAN = (0.485 * 255, 0.456 * 255, 0.406 * 255)
18+
TRAIN_STDDEV = (0.229 * 255, 0.224 * 255, 0.225 * 255)
19+
CENTER_CROP_SIZE = 224
20+
RESIZE_SIZE = 256
21+
ASPECT_RATIO_RANGE = (0.75, 4.0 / 3.0)
22+
SCALE_RATIO_RANGE = (0.08, 1.0)
23+
24+
25+
def main():
26+
USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup()
27+
28+
# Initialize DDP process group
29+
if USE_PYTORCH_DDP:
30+
torch.cuda.set_device(RANK)
31+
dist.init_process_group('nccl')
32+
33+
data_dir = '/home/ak4605/algoperf-data/imagenet/jax'
34+
global_batch_size = 1024
35+
num_batches = 100
36+
37+
if RANK == 0:
38+
print(f'Creating PyTorch ImageNet dataloader (shared TFDS pipeline)...')
39+
print(f'Batch size: {global_batch_size}')
40+
print(f'Num GPUs: {N_GPUS}')
41+
print(f'USE_PYTORCH_DDP: {USE_PYTORCH_DDP}')
42+
43+
# Calculate per-device batch size for DDP
44+
if USE_PYTORCH_DDP:
45+
batch_size = global_batch_size // N_GPUS
46+
else:
47+
batch_size = global_batch_size
48+
49+
if RANK == 0:
50+
print(f'Per-device batch size: {batch_size}')
51+
52+
rng = jax.random.PRNGKey(0)
53+
ds_builder = tfds.builder('imagenet2012:5.1.0', data_dir=data_dir)
54+
55+
ds = input_pipeline.create_split(
56+
split='train',
57+
dataset_builder=ds_builder,
58+
rng=rng,
59+
global_batch_size=batch_size,
60+
train=True,
61+
image_size=CENTER_CROP_SIZE,
62+
resize_size=RESIZE_SIZE,
63+
mean_rgb=TRAIN_MEAN,
64+
stddev_rgb=TRAIN_STDDEV,
65+
cache=False,
66+
repeat_final_dataset=True,
67+
aspect_ratio_range=ASPECT_RATIO_RANGE,
68+
area_range=SCALE_RATIO_RANGE,
69+
use_mixup=False,
70+
use_randaug=False,
71+
image_format='NCHW',
72+
threadpool_size=48 if USE_PYTORCH_DDP else 48,
73+
)
74+
75+
ds_iter = iter(ds)
76+
77+
def get_batch():
78+
batch = next(ds_iter)
79+
inputs = torch.from_numpy(batch['inputs'].numpy()).to(DEVICE)
80+
targets = torch.from_numpy(batch['targets'].numpy()).to(DEVICE, dtype=torch.long)
81+
return {'inputs': inputs, 'targets': targets}
82+
83+
# Warmup
84+
if RANK == 0:
85+
print('Warming up...')
86+
for i in range(5):
87+
start = time.perf_counter()
88+
batch = get_batch()
89+
end = time.perf_counter()
90+
if RANK == 0:
91+
print(f' Warmup batch {i+1}/5: {(end - start)*1000:.2f}ms')
92+
93+
if RANK == 0:
94+
print(f"Batch 'inputs' shape: {batch['inputs'].shape}")
95+
96+
# Synchronize before benchmark
97+
if USE_PYTORCH_DDP:
98+
dist.barrier()
99+
100+
# Benchmark
101+
if RANK == 0:
102+
print(f'Benchmarking {num_batches} batches...')
103+
times = []
104+
for i in range(num_batches):
105+
if USE_PYTORCH_DDP:
106+
dist.barrier()
107+
start = time.perf_counter()
108+
batch = get_batch()
109+
end = time.perf_counter()
110+
times.append(end - start)
111+
if RANK == 0 and (i + 1) % 20 == 0:
112+
print(f' Batch {i+1}/{num_batches}: {times[-1]*1000:.2f}ms')
113+
114+
times = np.array(times)
115+
if RANK == 0:
116+
print(f'\n=== PyTorch DataLoader Results ===')
117+
print(f'Mean time per batch: {times.mean()*1000:.2f}ms')
118+
print(f'Std time per batch: {times.std()*1000:.2f}ms')
119+
print(f'Min time per batch: {times.min()*1000:.2f}ms')
120+
print(f'Max time per batch: {times.max()*1000:.2f}ms')
121+
print(f'Throughput: {global_batch_size / times.mean():.2f} images/sec')
122+
123+
# Print machine-readable results for the fish script
124+
print(f'\n=== RESULTS ===')
125+
print(f'MEAN_MS={times.mean()*1000:.2f}')
126+
print(f'THROUGHPUT={global_batch_size / times.mean():.2f}')
127+
128+
if USE_PYTORCH_DDP:
129+
dist.destroy_process_group()
130+
131+
132+
if __name__ == '__main__':
133+
main()

debug/benchmark_dataloaders.fish

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
#!/usr/bin/env fish
2+
3+
# Benchmark script to compare JAX vs PyTorch ImageNet dataloaders
4+
# Usage: ./benchmark_dataloaders.fish
5+
6+
set script_dir (dirname (status filename))
7+
set pytorch_output "$script_dir/benchmark_dataloader_pytorch.txt"
8+
set jax_output "$script_dir/benchmark_dataloader_jax.txt"
9+
10+
echo "============================================="
11+
echo "ImageNet DataLoader Benchmark"
12+
echo "============================================="
13+
echo ""
14+
15+
# Run PyTorch benchmark with DDP (4 processes)
16+
echo ">>> Running PyTorch DataLoader Benchmark (DDP with 4 GPUs)..."
17+
echo ">>> Activating conda environment: ap11_torch_latest"
18+
conda activate ap11_torch_latest
19+
20+
echo ">>> Output will be saved to: $pytorch_output"
21+
torchrun --nproc_per_node=4 --standalone benchmark_dataloader_pytorch.py 2>&1 | tee $pytorch_output
22+
set pytorch_status $status
23+
24+
if test $pytorch_status -ne 0
25+
echo "PyTorch benchmark failed with status $pytorch_status"
26+
end
27+
28+
echo ""
29+
30+
# Run JAX benchmark
31+
echo ">>> Running JAX DataLoader Benchmark..."
32+
echo ">>> Activating conda environment: ap11_jax"
33+
conda activate ap11_jax
34+
35+
echo ">>> Output will be saved to: $jax_output"
36+
python benchmark_dataloader_jax.py 2>&1 | tee $jax_output
37+
set jax_status $status
38+
39+
if test $jax_status -ne 0
40+
echo "JAX benchmark failed with status $jax_status"
41+
end
42+
43+
echo ""
44+
45+
# Extract results from output files
46+
function extract_result
47+
set file $argv[1]
48+
set key $argv[2]
49+
grep "^$key=" $file | sed "s/$key=//"
50+
end
51+
52+
# Parse PyTorch results
53+
set pt_mean_ms (extract_result $pytorch_output "MEAN_MS")
54+
set pt_throughput (extract_result $pytorch_output "THROUGHPUT")
55+
56+
# Parse JAX results
57+
set jax_mean_ms (extract_result $jax_output "MEAN_MS")
58+
set jax_throughput (extract_result $jax_output "THROUGHPUT")
59+
60+
echo "============================================="
61+
echo " RESULTS TABLE"
62+
echo "============================================="
63+
echo ""
64+
printf "%-25s %15s %15s\n" "" "PyTorch" "JAX"
65+
echo "-------------------------------------------------------------"
66+
printf "%-25s %12s ms %12s ms\n" "Mean Time per Batch" "$pt_mean_ms" "$jax_mean_ms"
67+
printf "%-25s %12s/s %12s/s\n" "Throughput" "$pt_throughput" "$jax_throughput"
68+
echo "-------------------------------------------------------------"
69+
echo ""
70+
echo "Note: Both use shared TFDS/TFRecords input pipeline"
71+
echo " Batch size: 1024 (global)"
72+
echo ""

0 commit comments

Comments
 (0)