Commit 9c93fc2

some benchmarking steps
1 parent 2f865a1 commit 9c93fc2

File tree

4 files changed: +291 -29 lines

  algorithms/baselines/external_tuning/jax_nadamw_full_budget.py
  algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py
  benchmark_step_times.py
  submission_runner.py

algorithms/baselines/external_tuning/jax_nadamw_full_budget.py

Lines changed: 0 additions & 6 deletions
@@ -340,12 +340,6 @@ def update_params(
      dropout_rate,
    )
  )
-
-  # Log loss, grad_norm.
-  if global_step % 100 == 0 and workload.metrics_logger is not None:
-    workload.metrics_logger.append_scalar_metrics(
-      {'loss': loss.item(), 'grad_norm': grad_norm.item()}, global_step
-    )
  return (new_optimizer_state, opt_update_fn), new_params, new_model_state

algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py

Lines changed: 0 additions & 22 deletions
@@ -300,28 +300,6 @@ def update_params(
  optimizer_state['optimizer'].step()
  optimizer_state['scheduler'].step()

-  # Log training metrics - loss, grad_norm, batch_size.
-  if global_step <= 100 or global_step % 500 == 0:
-    with torch.no_grad():
-      parameters = [p for p in current_model.parameters() if p.grad is not None]
-      grad_norm = torch.norm(
-        torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2
-      )
-    if workload.metrics_logger is not None:
-      workload.metrics_logger.append_scalar_metrics(
-        {
-          'loss': loss.item(),
-          'grad_norm': grad_norm.item(),
-        },
-        global_step,
-      )
-    logging.info(
-      '%d) loss = %0.3f, grad_norm = %0.3f',
-      global_step,
-      loss.item(),
-      grad_norm.item(),
-    )
-
  return (optimizer_state, current_param_container, new_model_state)

benchmark_step_times.py

Lines changed: 274 additions & 0 deletions
@@ -0,0 +1,274 @@
#!/usr/bin/env python3
"""Benchmark step times for JAX and PyTorch across all workloads.

This script runs each workload for MAX_STEPS steps with both JAX and PyTorch,
captures the step_time_ms metric, and produces a comparison table.
"""

import argparse
import re
import subprocess
from pathlib import Path

# Base workloads to benchmark
WORKLOADS = [
  'imagenet_resnet',
]

FRAMEWORKS = ['jax', 'pytorch']
MAX_STEPS = 201
OUTPUT_DIR = Path('/home/ak4605/aef2/benchmark_outputs')


def get_data_dir(workload: str, framework: str) -> str:
  """Map workload to its data directory."""
  if workload in ['imagenet_resnet', 'imagenet_vit']:
    return '/opt/data/imagenet/' + framework
  elif workload in ['librispeech_conformer', 'librispeech_deepspeech']:
    return '/opt/data/librispeech'
  elif workload == 'criteo1tb':
    return '/opt/data/criteo1tb'
  elif workload == 'fastmri':
    return '/opt/data/fastmri'
  elif workload == 'ogbg':
    return '/opt/data/ogbg'
  elif workload == 'wmt':
    return '/opt/data/wmt'
  else:
    return '/opt/'


def run_workload(workload: str, framework: str, output_file: Path) -> bool:
  """Run a workload and capture output to file."""
  data_dir = get_data_dir(workload, framework)
  experiment_dir = '/home/ak4605/experiments'

  # Clean up previous experiment directories
  for item in Path(experiment_dir).glob(f'{workload}*'):
    if item.is_dir():
      subprocess.run(['rm', '-rf', str(item)], check=True)

  # Build command based on framework
  submission_path = (
    f'algorithms/baselines/external_tuning/{framework}_nadamw_full_budget.py'
  )
  tuning_search_space = (
    'algorithms/baselines/external_tuning/tuning_search_space.json'
  )

  if framework == 'jax':
    cmd = [
      'python',
      'submission_runner.py',
      f'--framework={framework}',
      f'--workload={workload}',
      f'--data_dir={data_dir}',
      f'--experiment_dir={experiment_dir}',
      f'--experiment_name={workload}_benchmark',
      f'--submission_path={submission_path}',
      f'--tuning_search_space={tuning_search_space}',
      f'--max_global_steps={MAX_STEPS}',
      '--skip_evals',
      '--nosave_checkpoints',
      '--nosave_intermediate_checkpoints',
    ]
    # For JAX, activate the jax conda environment
    activate_cmd = 'source $(conda info --base)/etc/profile.d/conda.sh && conda activate ap11_jax && '
  else:
    cmd = [
      'torchrun',
      '--nproc_per_node=4',
      '--standalone',
      'submission_runner.py',
      f'--framework={framework}',
      f'--workload={workload}',
      f'--data_dir={data_dir}',
      f'--experiment_dir={experiment_dir}',
      f'--experiment_name={workload}_benchmark',
      f'--submission_path={submission_path}',
      f'--tuning_search_space={tuning_search_space}',
      f'--max_global_steps={MAX_STEPS}',
      '--skip_evals',
      '--nosave_checkpoints',
      '--nosave_intermediate_checkpoints',
    ]
    # For PyTorch, activate the torch conda environment
    activate_cmd = 'source $(conda info --base)/etc/profile.d/conda.sh && conda activate ap11_torch_latest && '

  # Run the command with shell to handle conda activation
  full_cmd = activate_cmd + ' '.join(cmd)
  print(f'Running: {workload} with {framework}')
  print(f'Output will be saved to: {output_file}')

  with open(output_file, 'w') as f:
    result = subprocess.run(
      full_cmd,
      shell=True,
      executable='/bin/bash',
      stdout=f,
      stderr=subprocess.STDOUT,
      cwd='/home/ak4605/aef2/',
    )

  return result.returncode == 0


def parse_step_time(output_file: Path) -> float | None:
  """Parse the last step_time_ms from output file."""
  if not output_file.exists():
    return None

  with open(output_file, 'r') as f:
    content = f.read()

  # Find all step_time_ms values
  # Pattern matches: step_time_ms=123.456 or 'step_time_ms': 123.456
  pattern = r'step_time_ms[\'"]?\s*[=:]\s*([\d.]+)'
  matches = re.findall(pattern, content)

  if matches:
    # Return the last value (most recent average)
    return float(matches[-1])
  return None


def parse_args():
  parser = argparse.ArgumentParser(
    description='Benchmark step times for JAX and PyTorch across workloads.'
  )
  group = parser.add_mutually_exclusive_group()
  group.add_argument(
    '--torch-only',
    action='store_true',
    help='Only run PyTorch experiments; read existing JAX results from files.',
  )
  group.add_argument(
    '--jax-only',
    action='store_true',
    help='Only run JAX experiments; read existing PyTorch results from files.',
  )
  group.add_argument(
    '--just-read',
    action='store_true',
    help='Do not run any experiments; just read and compare existing outputs.',
  )
  return parser.parse_args()


def main():
  args = parse_args()

  # Create output directory
  OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

  results = {}

  # Determine which frameworks to run vs read from files
  if args.just_read:
    frameworks_to_run = []
    frameworks_to_read = FRAMEWORKS
  elif args.torch_only:
    frameworks_to_run = ['pytorch']
    frameworks_to_read = ['jax']
  elif args.jax_only:
    frameworks_to_run = ['jax']
    frameworks_to_read = ['pytorch']
  else:
    frameworks_to_run = FRAMEWORKS
    frameworks_to_read = []

  # Run all workloads
  for workload in WORKLOADS:
    results[workload] = {}

    # Read existing results from files
    for framework in frameworks_to_read:
      output_file = OUTPUT_DIR / f'{workload}_{framework}.out'
      step_time = parse_step_time(output_file)
      results[workload][framework] = step_time
      if step_time:
        print(f'\nLoaded existing {framework.upper()} result for {workload}: {step_time:.2f} ms')
      else:
        print(f'\nNo existing {framework.upper()} result found for {workload}')

    # Run experiments for specified frameworks
    for framework in frameworks_to_run:
      output_file = OUTPUT_DIR / f'{workload}_{framework}.out'

      print(f'\n{"=" * 60}')
      print(f'Benchmarking {workload} with {framework}')
      print(f'{"=" * 60}')

      success = run_workload(workload, framework, output_file)

      if success:
        step_time = parse_step_time(output_file)
        results[workload][framework] = step_time
        print(
          f'Step time: {step_time:.2f} ms' if step_time else 'Step time: N/A'
        )
      else:
        results[workload][framework] = None
        print(f'Failed to run {workload} with {framework}')

  # Print results table
  print('\n\n')
  print('=' * 80)
  print('STEP TIME COMPARISON (ms)')
  print('=' * 80)
  print(
    f'{"Workload":<30} {"JAX (ms)":<15} {"PyTorch (ms)":<15} {"Ratio (PT/JAX)":<15}'
  )
  print('-' * 80)

  for workload in WORKLOADS:
    jax_time = results[workload].get('jax')
    pytorch_time = results[workload].get('pytorch')

    jax_str = f'{jax_time:.2f}' if jax_time else 'N/A'
    pytorch_str = f'{pytorch_time:.2f}' if pytorch_time else 'N/A'

    if jax_time and pytorch_time:
      ratio = pytorch_time / jax_time
      ratio_str = f'{ratio:.2f}x'
    else:
      ratio_str = 'N/A'

    print(f'{workload:<30} {jax_str:<15} {pytorch_str:<15} {ratio_str:<15}')

  print('=' * 80)

  # Save results to file
  results_file = OUTPUT_DIR / 'results.txt'
  with open(results_file, 'w') as f:
    f.write('STEP TIME COMPARISON (ms)\n')
    f.write('=' * 80 + '\n')
    f.write(
      f'{"Workload":<30} {"JAX (ms)":<15} {"PyTorch (ms)":<15} {"Ratio (PT/JAX)":<15}\n'
    )
    f.write('-' * 80 + '\n')

    for workload in WORKLOADS:
      jax_time = results[workload].get('jax')
      pytorch_time = results[workload].get('pytorch')

      jax_str = f'{jax_time:.2f}' if jax_time else 'N/A'
      pytorch_str = f'{pytorch_time:.2f}' if pytorch_time else 'N/A'

      if jax_time and pytorch_time:
        ratio = pytorch_time / jax_time
        ratio_str = f'{ratio:.2f}x'
      else:
        ratio_str = 'N/A'

      f.write(
        f'{workload:<30} {jax_str:<15} {pytorch_str:<15} {ratio_str:<15}\n'
      )

    f.write('=' * 80 + '\n')

  print(f'\nResults saved to: {results_file}')


if __name__ == '__main__':
  main()

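A quick way to sanity-check the parsing logic above before launching long runs is to feed it a synthetic log line. The snippet below is only a sketch: the sample lines assume a particular rendering of what append_scalar_metrics writes to the captured stdout, which may not match the real log format exactly.

# Standalone check of the step_time_ms parsing used by parse_step_time().
# The sample lines below are an assumed log format, not a confirmed one.
import re

PATTERN = r'step_time_ms[\'"]?\s*[=:]\s*([\d.]+)'
sample_output = """
Metrics at step 101: {'step_time_ms': 305.17}
Metrics at step 201: {'step_time_ms': 298.42}
"""

matches = re.findall(PATTERN, sample_output)
assert matches, 'no step_time_ms entries found'
print(f'last step_time_ms = {float(matches[-1]):.2f} ms')  # prints 298.42 ms

The pattern also accepts the unquoted forms step_time_ms=298.42 and step_time_ms: 298.42, so both styles mentioned in the parse_step_time comment are covered.
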
submission_runner.py

Lines changed: 17 additions & 1 deletion
@@ -352,7 +352,7 @@ def train_once(
      log_dir, flags.FLAGS, hyperparameters
    )
    workload.attach_metrics_logger(metrics_logger)
-
+  step_10_end_time = None
  global_start_time = get_time()
  train_state['last_step_end_time'] = global_start_time

@@ -409,6 +409,22 @@ def train_once(
      train_state['training_complete'] = True

    train_step_end_time = get_time()
+    if global_step == 11:
+      step_10_end_time = train_step_end_time
+
+    # Log step time every 100 steps
+    # Note: global_step was incremented, so use (global_step - 1) to match
+    if (global_step - 1) % 100 == 0 and workload.metrics_logger is not None:
+      if step_10_end_time is not None and global_step > 11:
+        elapsed_time_ms = (train_step_end_time - step_10_end_time) * 1000.0
+        elapsed_steps = global_step - 11
+        avg_step_time_ms = elapsed_time_ms / elapsed_steps
+      else:
+        avg_step_time_ms = 0.0
+      workload.metrics_logger.append_scalar_metrics(
+        {'step_time_ms': avg_step_time_ms},
+        global_step - 1,
+      )

    train_state['accumulated_submission_time'] += (
      train_step_end_time - train_state['last_step_end_time']

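The added block reports an average step time every 100 steps, measured from the end of step 10 onward, so the first ten steps (which typically include compilation and input-pipeline warm-up) are excluded. A minimal sketch of that arithmetic with invented timestamps, outside the runner:

# Sketch of the averaging performed by the added lines, using made-up numbers.
# step_10_end_time is recorded when global_step == 11, i.e. right after step 10
# completes; every 100th step, the average over the steps since then is logged.
step_10_end_time = 12.0      # seconds, end of step 10
train_step_end_time = 50.0   # seconds, end of step 100 (global_step == 101)
global_step = 101

elapsed_time_ms = (train_step_end_time - step_10_end_time) * 1000.0  # 38000.0 ms
elapsed_steps = global_step - 11                                     # 90 steps (11..100)
avg_step_time_ms = elapsed_time_ms / elapsed_steps
print(f'{avg_step_time_ms:.2f} ms/step')  # 422.22 ms/step

benchmark_step_times.py then reads the last step_time_ms entry from the captured output, so the reported figure is the running average over all measured steps up to the last logged step.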