Skip to content

Commit e3322b7

Browse files
committed
feat: resume from last run
1 parent 8cbae96 commit e3322b7

File tree

1 file changed

+27
-23
lines changed

1 file changed

+27
-23
lines changed

training.py

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import flax
2020
import tqdm
2121
import jax.numpy as jnp
22-
22+
import re
2323
import optax
2424
import time
2525
import os
@@ -140,8 +140,8 @@ def boolean_string(s):
140140
parser.add_argument('--experiment_name', type=str, default=None, help='Experiment name, would be generated if not provided')
141141
parser.add_argument('--load_from_checkpoint', type=str,
142142
default=None, help='Load from the best previously stored checkpoint. The checkpoint path should be provided')
143-
parser.add_argument('--resume_last_run', type=boolean_string,
144-
default=False, help='Resume the last run from the experiment name')
143+
parser.add_argument('--resume_last_run', type=str,
144+
default=None, help='Resume the last run from the experiment name')
145145
parser.add_argument('--dataset_seed', type=int, default=0, help='Dataset starting seed')
146146

147147
parser.add_argument('--dataset_test', type=boolean_string,
@@ -452,29 +452,30 @@ def main(args):
452452
karas_ve_schedule = KarrasVENoiseScheduler(
453453
1, sigma_max=80, rho=7, sigma_data=0.5)
454454
edm_schedule = EDMNoiseScheduler(1, sigma_max=80, rho=7, sigma_data=0.5)
455-
456-
if args.experiment_name and args.experiment_name != "":
455+
456+
if args.experiment_name is not None:
457457
experiment_name = args.experiment_name
458-
if not args.resume_last_run:
459-
experiment_name = f"{experiment_name}/" + "arguments_hash-{arguments_hash}/date-{date}"
460-
else:
461-
# TODO: Add logic to load the last run from wandb
462-
pass
463458
else:
464-
experiment_name = "{name}_{date}".format(
465-
name="Diffusion_SDE_VE_TEXT", date=datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
466-
)
459+
experiment_name = "manual-dataset-{dataset}/image_size-{image_size}/batch-{batch_size}/schd-{noise_schedule}/dtype-{dtype}/arch-{architecture}/lr-{learning_rate}/resblks-{num_res_blocks}/emb-{emb_features}/pure-attn-{only_pure_attention}"
467460

468-
if autoencoder is not None:
469-
experiment_name = f"LDM-{experiment_name}"
470-
471-
if args.use_hilbert:
472-
experiment_name = f"Hilbert-{experiment_name}"
473-
474-
conf_args = CONFIG['arguments']
475-
conf_args['date'] = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
476-
conf_args['arguments_hash'] = arguments_hash
477-
experiment_name = experiment_name.format(**conf_args)
461+
# Check if format strings are required using regex
462+
pattern = r"\{.+?\}"
463+
if re.search(pattern, experiment_name):
464+
experiment_name = f"{experiment_name}/" + "arguments_hash-{arguments_hash}/date-{date}"
465+
if autoencoder is not None:
466+
experiment_name = f"LDM-{experiment_name}"
467+
468+
if args.use_hilbert:
469+
experiment_name = f"Hilbert-{experiment_name}"
470+
471+
conf_args = CONFIG['arguments']
472+
conf_args['date'] = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
473+
conf_args['arguments_hash'] = arguments_hash
474+
# Format the string with the arguments
475+
experiment_name = experiment_name.format(**vars(args))
476+
else:
477+
# If no format strings, just use the provided name
478+
experiment_name = args.experiment_name
478479

479480
print("Experiment_Name:", experiment_name)
480481

@@ -511,6 +512,9 @@ def main(args):
511512
"name": experiment_name,
512513
}
513514

515+
if args.resume_last_run is not None:
516+
wandb_config['id'] = args.resume_last_run
517+
514518
start_time = time.time()
515519

516520
trainer = GeneralDiffusionTrainer(

0 commit comments

Comments
 (0)