Commit 7d5229d

Merge pull request #661 from mlcommons/self_tuning_docker
[WIP] Self tuning workflow fixes
2 parents 77f1367 + 2d7c27a commit 7d5229d

File tree: 6 files changed, +150 -34 lines changed

docker/scripts/startup.sh

Lines changed: 38 additions & 8 deletions
@@ -50,6 +50,7 @@ HOME_DIR=""
 RSYNC_DATA="true"
 OVERWRITE="false"
 SAVE_CHECKPOINTS="true"
+TUNING_RULESET="external"
 
 # Pass flag
 while [ "$1" != "" ]; do
@@ -107,6 +108,10 @@ while [ "$1" != "" ]; do
      shift
      HOME_DIR=$1
      ;;
+    --tuning_ruleset)
+      shift
+      TUNING_RULESET=$1
+      ;;
    --num_tuning_trials)
      shift
      NUM_TUNING_TRIALS=$1
@@ -157,6 +162,7 @@ VALID_WORKLOADS=("criteo1tb" "imagenet_resnet" "imagenet_resnet_silu" "imagenet_
   "librispeech_deepspeech_tanh" \
   "librispeech_deepspeech_no_resnet" "librispeech_deepspeech_norm_and_spec_aug"
   "fastmri_layernorm" "ogbg_gelu" "ogbg_silu" "ogbg_model_size")
+VALID_RULESETS=("self" "external")
 
 # Set data and experiment paths
 ROOT_DATA_BUCKET="gs://mlcommons-data"
@@ -167,17 +173,25 @@ EXPERIMENT_DIR="${HOME_DIR}/experiment_runs"
 
 if [[ -n ${DATASET+x} ]]; then
   if [[ ! " ${VALID_DATASETS[@]} " =~ " $DATASET " ]]; then
-    echo "Error: invalid argument for dataset (d)."
+    echo "Error: invalid argument $DATASET for dataset (d)."
     exit 1
   fi
 fi
 
 if [[ -n ${WORKLOAD+x} ]]; then
   if [[ ! " ${VALID_WORKLOADS[@]} " =~ " $WORKLOAD " ]]; then
-    echo "Error: invalid argument for workload (w)."
+    echo "Error: invalid argument $WORKLOAD for workload (w)."
+    exit 1
+  fi
+fi
+
+if [[ -n ${TUNING_RULESET+x} ]]; then
+  if [[ ! " ${VALID_RULESETS[@]} " =~ " $TUNING_RULESET " ]]; then
+    echo "Error: invalid argument $TUNING_RULESET for tuning ruleset (tuning_ruleset)."
     exit 1
   fi
 fi
+TUNING_RULESET_FLAG="--tuning_ruleset=${TUNING_RULESET}"
 
 # Set run command prefix depending on framework
 if [[ "${FRAMEWORK}" == "jax" ]]; then
@@ -243,26 +257,42 @@ if [[ ! -z ${SUBMISSION_PATH+x} ]]; then
   if [[ ${FRAMEWORK} == "pytorch" ]]; then
     TORCH_COMPILE_FLAG="--torch_compile=true"
   fi
+
+  # Flags for rulesets
+  if [[ ${TUNING_RULESET} == "external" ]]; then
+    TUNING_SEARCH_SPACE_FLAG="--tuning_search_space=${TUNING_SEARCH_SPACE}"
+  fi
 
   # The TORCH_RUN_COMMAND_PREFIX is only set if FRAMEWORK is "pytorch"
-  COMMAND="${COMMAND_PREFIX} submission_runner.py \
+  BASE_COMMAND="${COMMAND_PREFIX} submission_runner.py \
    --framework=${FRAMEWORK} \
    --workload=${WORKLOAD} \
    --submission_path=${SUBMISSION_PATH} \
-    --tuning_search_space=${TUNING_SEARCH_SPACE} \
    --data_dir=${DATA_DIR} \
    --num_tuning_trials=1 \
    --experiment_dir=${EXPERIMENT_DIR} \
    --experiment_name=${EXPERIMENT_NAME} \
    --overwrite=${OVERWRITE} \
    --save_checkpoints=${SAVE_CHECKPOINTS} \
-    ${NUM_TUNING_TRIALS_FLAG} \
-    ${HPARAM_START_INDEX_FLAG} \
-    ${HPARAM_END_INDEX_FLAG} \
    ${RNG_SEED_FLAG} \
    ${MAX_STEPS_FLAG} \
    ${SPECIAL_FLAGS} \
-    ${TORCH_COMPILE_FLAG} 2>&1 | tee -a ${LOG_FILE}"
+    ${TORCH_COMPILE_FLAG}"
+
+  if [[ ${TUNING_RULESET} == "external" ]]; then
+    COMMAND="${BASE_COMMAND} \
+    ${TUNING_RULESET_FLAG} \
+    ${TUNING_SEARCH_SPACE_FLAG} \
+    ${NUM_TUNING_TRIALS_FLAG} \
+    ${HPARAM_START_INDEX_FLAG} \
+    ${HPARAM_END_INDEX_FLAG}"
+  else
+    COMMAND="${BASE_COMMAND} \
+    ${TUNING_RULESET_FLAG}"
+  fi
+
+  COMMAND="$COMMAND 2>&1 | tee -a ${LOG_FILE}"
+
  echo $COMMAND > ${LOG_FILE}
  echo $COMMAND
  eval $COMMAND
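
The branch added above builds the final submission_runner.py command from BASE_COMMAND plus ruleset-specific flags: external tuning forwards the search-space and hparam-index flags, while self tuning only forwards --tuning_ruleset. Below is a minimal Python sketch of that selection logic, for comparison only; the *_flag parameters stand in for the shell variables of the same names, whose contents are set elsewhere in startup.sh and are not shown in this diff.

# Sketch only, not part of the commit.
def build_command(tuning_ruleset: str,
                  base_command: str,
                  tuning_search_space_flag: str = '',
                  num_tuning_trials_flag: str = '',
                  hparam_start_index_flag: str = '',
                  hparam_end_index_flag: str = '',
                  log_file: str = 'run.log') -> str:
  # The tuning ruleset flag is always appended, mirroring the script above.
  tuning_ruleset_flag = f'--tuning_ruleset={tuning_ruleset}'
  if tuning_ruleset == 'external':
    parts = [base_command,
             tuning_ruleset_flag,
             tuning_search_space_flag,
             num_tuning_trials_flag,
             hparam_start_index_flag,
             hparam_end_index_flag]
  else:  # 'self'
    parts = [base_command, tuning_ruleset_flag]
  # Matches COMMAND="$COMMAND 2>&1 | tee -a ${LOG_FILE}".
  return ' '.join(p for p in parts if p) + f' 2>&1 | tee -a {log_file}'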

prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py

Lines changed: 8 additions & 8 deletions
@@ -180,14 +180,14 @@ def init_optimizer_state(workload: spec.Workload,
 
 def jax_cosine_warmup(step_hint: int, hyperparameters):
   # Create learning rate schedule.
-  warmup_steps = int(hyperparameters.warmup_factor * step_hint)
+  warmup_steps = int(hyperparameters['warmup_factor'] * step_hint)
   warmup_fn = optax.linear_schedule(
       init_value=0.,
-      end_value=hyperparameters.learning_rate,
+      end_value=hyperparameters['learning_rate'],
       transition_steps=warmup_steps)
   cosine_steps = max(step_hint - warmup_steps, 1)
   cosine_fn = optax.cosine_decay_schedule(
-      init_value=hyperparameters.learning_rate, decay_steps=cosine_steps)
+      init_value=hyperparameters['learning_rate'], decay_steps=cosine_steps)
   schedule_fn = optax.join_schedules(
       schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps])
   return schedule_fn
@@ -196,10 +196,10 @@ def jax_cosine_warmup(step_hint: int, hyperparameters):
   lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters)
   opt_init_fn, opt_update_fn = nadamw(
       learning_rate=lr_schedule_fn,
-      b1=1.0 - hyperparameters.one_minus_beta1,
-      b2=hyperparameters.beta2,
+      b1=1.0 - hyperparameters['one_minus_beta1'],
+      b2=hyperparameters['beta2'],
       eps=1e-8,
-      weight_decay=hyperparameters.weight_decay)
+      weight_decay=hyperparameters['weight_decay'])
   params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple),
                                    workload.param_shapes)
   optimizer_state = opt_init_fn(params_zeros_like)
@@ -286,11 +286,11 @@ def update_params(workload: spec.Workload,
   optimizer_state, opt_update_fn = optimizer_state
   per_device_rngs = jax.random.split(rng, jax.local_device_count())
   if hasattr(hyperparameters, 'label_smoothing'):
-    label_smoothing = hyperparameters.label_smoothing
+    label_smoothing = hyperparameters['label_smoothing']
   else:
     label_smoothing = 0.0
   if hasattr(hyperparameters, 'grad_clip'):
-    grad_clip = hyperparameters.grad_clip
+    grad_clip = hyperparameters['grad_clip']
   else:
     grad_clip = None
   outputs = pmapped_train_step(workload,

prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py

Lines changed: 8 additions & 8 deletions
@@ -180,14 +180,14 @@ def init_optimizer_state(workload: spec.Workload,
 
 def jax_cosine_warmup(step_hint: int, hyperparameters):
   # Create learning rate schedule.
-  warmup_steps = int(hyperparameters.warmup_factor * step_hint)
+  warmup_steps = int(hyperparameters['warmup_factor'] * step_hint)
   warmup_fn = optax.linear_schedule(
       init_value=0.,
-      end_value=hyperparameters.learning_rate,
+      end_value=hyperparameters['learning_rate'],
       transition_steps=warmup_steps)
   cosine_steps = max(step_hint - warmup_steps, 1)
   cosine_fn = optax.cosine_decay_schedule(
-      init_value=hyperparameters.learning_rate, decay_steps=cosine_steps)
+      init_value=hyperparameters['learning_rate'], decay_steps=cosine_steps)
   schedule_fn = optax.join_schedules(
       schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps])
   return schedule_fn
@@ -196,10 +196,10 @@ def jax_cosine_warmup(step_hint: int, hyperparameters):
   lr_schedule_fn = jax_cosine_warmup(workload.step_hint * 0.75, hyperparameters)
   opt_init_fn, opt_update_fn = nadamw(
       learning_rate=lr_schedule_fn,
-      b1=1.0 - hyperparameters.one_minus_beta1,
-      b2=hyperparameters.beta2,
+      b1=1.0 - hyperparameters['one_minus_beta1'],
+      b2=hyperparameters['beta2'],
       eps=1e-8,
-      weight_decay=hyperparameters.weight_decay)
+      weight_decay=hyperparameters['weight_decay'])
   params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple),
                                    workload.param_shapes)
   optimizer_state = opt_init_fn(params_zeros_like)
@@ -286,11 +286,11 @@ def update_params(workload: spec.Workload,
   optimizer_state, opt_update_fn = optimizer_state
   per_device_rngs = jax.random.split(rng, jax.local_device_count())
   if hasattr(hyperparameters, 'label_smoothing'):
-    label_smoothing = hyperparameters.label_smoothing
+    label_smoothing = hyperparameters['label_smoothing']
   else:
     label_smoothing = 0.0
   if hasattr(hyperparameters, 'grad_clip'):
-    grad_clip = hyperparameters.grad_clip
+    grad_clip = hyperparameters['grad_clip']
   else:
     grad_clip = None
   outputs = pmapped_train_step(workload,
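
Both self-tuning baseline files above make the same mechanical change: hyperparameter lookups switch from attribute access to dictionary access, since in these baselines the hyperparameters object is a plain dict of fixed values rather than the attribute-style trial object used under external tuning. Below is a minimal sketch of the difference; the HPARAMS dict and its values are illustrative placeholders, not taken from the commit.

from collections import namedtuple

# Placeholder values for illustration only.
HPARAMS = {
    'learning_rate': 1e-3,
    'one_minus_beta1': 0.1,
    'beta2': 0.999,
    'weight_decay': 1e-2,
    'warmup_factor': 0.05,
}

# External-tuning style: each trial behaves like an object with attributes,
# which is why the removed lines wrote hyperparameters.learning_rate.
Hyperparameters = namedtuple('Hyperparameters', HPARAMS.keys())
external_style = Hyperparameters(**HPARAMS)
assert external_style.learning_rate == HPARAMS['learning_rate']

# Self-tuning style: the baseline reads a plain dict, so the added lines
# write hyperparameters['learning_rate'] instead.
self_style = dict(HPARAMS)
assert self_style['learning_rate'] == HPARAMS['learning_rate']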

scoring/run_workloads.py

Lines changed: 28 additions & 10 deletions
@@ -50,6 +50,11 @@
     False,
     'Whether or not to actually run the docker containers. '
     'If False, simply print the docker run commands. ')
+flags.DEFINE_enum(
+    'tuning_ruleset',
+    'external',
+    enum_values=['external', 'self'],
+    help='Can be either external or self.')
 flags.DEFINE_integer('num_studies', 5, 'Number of studies to run')
 flags.DEFINE_integer('study_start_index', None, 'Start index for studies.')
 flags.DEFINE_integer('study_end_index', None, 'End index for studies.')
@@ -66,11 +71,13 @@
     None,
     'Path to config containing held-out workloads')
 flags.DEFINE_string(
-    'workload_meta_data_config_path',
-    'workload_meta_data.json',
+    'workload_metadata_path',
+    None,
     'Path to config containing dataset and maximum number of steps per workload.'
     'The default values of these are set to the full budgets as determined '
     'via the target-setting procedure. '
+    'We provide workload_metadata_external_tuning.json and '
+    'workload_metadata_self_tuning.json as references.'
     'Note that training will be interrupted at either the set maximum number '
     'of steps or the fixed workload maximum run time, whichever comes first. '
     'If your algorithm has a smaller per step time than our baselines '
@@ -126,10 +133,10 @@ def main(_):
   logging.info('Using RNG seed %d', rng_seed)
   rng_key = (prng.fold_in(prng.PRNGKey(rng_seed), hash(submission_id)))
 
-  with open(FLAGS.workload_meta_data_config_path) as f:
-    workload_meta_data = json.load(f)
+  with open(FLAGS.workload_metadata_path) as f:
+    workload_metadata = json.load(f)
 
-  workloads = [w for w in workload_meta_data.keys()]
+  workloads = [w for w in workload_metadata.keys()]
 
   # Read held-out workloads
   if FLAGS.held_out_workloads_config_path:
@@ -154,8 +161,8 @@ def main(_):
       os.system(
          "sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches'") # clear caches
       print('=' * 100)
-      dataset = workload_meta_data[base_workload_name]['dataset']
-      max_steps = int(workload_meta_data[base_workload_name]['max_steps'] *
+      dataset = workload_metadata[base_workload_name]['dataset']
+      max_steps = int(workload_metadata[base_workload_name]['max_steps'] *
                      run_fraction)
       mount_repo_flag = ''
       if FLAGS.local:
@@ -170,16 +177,26 @@ def main(_):
          f'-f {framework} '
          f'-s {submission_path} '
          f'-w {workload} '
-         f'-t {tuning_search_space} '
          f'-e {study_dir} '
          f'-m {max_steps} '
          f'--num_tuning_trials {num_tuning_trials} '
-         f'{hparam_start_index_flag} '
-         f'{hparam_end_index_flag} '
          f'--rng_seed {run_seed} '
          '-c false '
          '-o true '
          '-i true ')
+
+      # Append tuning ruleset flags
+      tuning_ruleset_flags = ''
+      if FLAGS.tuning_ruleset == 'external':
+        tuning_ruleset_flags += f'--tuning_ruleset {FLAGS.tuning_ruleset} '
+        tuning_ruleset_flags += f'-t {tuning_search_space} '
+        tuning_ruleset_flags += f'{hparam_start_index_flag} '
+        tuning_ruleset_flags += f'{hparam_end_index_flag}'
+      else:
+        tuning_ruleset_flags += f'--tuning_ruleset {FLAGS.tuning_ruleset}'
+
+      command += tuning_ruleset_flags
+
       if not FLAGS.dry_run:
         print('Running docker container command')
         print('Container ID: ')
@@ -205,4 +222,5 @@ def main(_):
 
 
 if __name__ == '__main__':
+  flags.mark_flag_as_required('workload_metadata_path')
   app.run(main)
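
run_workloads.py now reads each workload's dataset name and step budget from the JSON file passed via --workload_metadata_path. Below is a minimal sketch of that lookup; it assumes the reference file sits in a local scoring/ directory (the directory is not shown in this diff) and uses an illustrative run_fraction.

import json

# Assumed location of the reference file added in this commit; adjust the
# path to wherever the file lives in your checkout.
METADATA_PATH = 'scoring/workload_metadata_external_tuning.json'
RUN_FRACTION = 0.1  # illustrative value, not from the commit

with open(METADATA_PATH) as f:
  workload_metadata = json.load(f)

# Mirrors the lookups in main(): one dataset name and a scaled step budget
# per base workload.
for base_workload_name in workload_metadata:
  dataset = workload_metadata[base_workload_name]['dataset']
  max_steps = int(workload_metadata[base_workload_name]['max_steps'] * RUN_FRACTION)
  print(f'{base_workload_name}: dataset={dataset}, max_steps={max_steps}')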

workload_metadata_external_tuning.json

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
+{
+  "imagenet_resnet": {
+    "max_steps": 186666,
+    "dataset": "imagenet"
+  },
+  "imagenet_vit": {
+    "max_steps": 186666,
+    "dataset": "imagenet"
+  },
+  "fastmri": {
+    "max_steps": 36189,
+    "dataset": "fastmri"
+  },
+  "ogbg": {
+    "max_steps": 80000,
+    "dataset": "ogbg"
+  },
+  "wmt": {
+    "max_steps": 133333,
+    "dataset": "wmt"
+  },
+  "librispeech_deepspeech": {
+    "max_steps": 48000,
+    "dataset": "librispeech"
+  },
+  "criteo1tb": {
+    "max_steps": 10666,
+    "dataset": "criteo1tb"
+  },
+  "librispeech_conformer": {
+    "max_steps": 80000,
+    "dataset": "librispeech"
+  }
+}

workload_metadata_self_tuning.json

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
+{
+  "imagenet_resnet": {
+    "max_steps": 559998,
+    "dataset": "imagenet"
+  },
+  "imagenet_vit": {
+    "max_steps": 559998,
+    "dataset": "imagenet"
+  },
+  "fastmri": {
+    "max_steps": 108567,
+    "dataset": "fastmri"
+  },
+  "ogbg": {
+    "max_steps": 240000,
+    "dataset": "ogbg"
+  },
+  "wmt": {
+    "max_steps": 399999,
+    "dataset": "wmt"
+  },
+  "librispeech_deepspeech": {
+    "max_steps": 144000,
+    "dataset": "librispeech"
+  },
+  "criteo1tb": {
+    "max_steps": 31998,
+    "dataset": "criteo1tb"
+  },
+  "librispeech_conformer": {
+    "max_steps": 240000,
+    "dataset": "librispeech"
+  }
+}
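
The two reference files added above differ only in their step budgets: every value in the second file is exactly three times the corresponding value in the first (for example, 186666 * 3 = 559998 for imagenet_resnet). Below is a quick check of that relationship; the scoring/ paths are an assumption about where the files live.

import json

with open('scoring/workload_metadata_external_tuning.json') as f:
  external = json.load(f)
with open('scoring/workload_metadata_self_tuning.json') as f:
  self_tuning = json.load(f)

for workload, meta in external.items():
  # Same dataset, three times the step budget in the self-tuning file.
  assert self_tuning[workload]['dataset'] == meta['dataset']
  assert self_tuning[workload]['max_steps'] == 3 * meta['max_steps']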
