|
19 | 19 | import flax |
20 | 20 | import tqdm |
21 | 21 | import jax.numpy as jnp |
22 | | - |
| 22 | +import re |
23 | 23 | import optax |
24 | 24 | import time |
25 | 25 | import os |
@@ -140,8 +140,8 @@ def boolean_string(s): |
140 | 140 | parser.add_argument('--experiment_name', type=str, default=None, help='Experiment name, would be generated if not provided') |
141 | 141 | parser.add_argument('--load_from_checkpoint', type=str, |
142 | 142 | default=None, help='Load from the best previously stored checkpoint. The checkpoint path should be provided') |
143 | | -parser.add_argument('--resume_last_run', type=boolean_string, |
144 | | - default=False, help='Resume the last run from the experiment name') |
| 143 | +parser.add_argument('--resume_last_run', type=str, |
| 144 | + default=None, help='Resume the last run from the experiment name') |
145 | 145 | parser.add_argument('--dataset_seed', type=int, default=0, help='Dataset starting seed') |
146 | 146 |
|
147 | 147 | parser.add_argument('--dataset_test', type=boolean_string, |
@@ -452,29 +452,30 @@ def main(args): |
452 | 452 | karas_ve_schedule = KarrasVENoiseScheduler( |
453 | 453 | 1, sigma_max=80, rho=7, sigma_data=0.5) |
454 | 454 | edm_schedule = EDMNoiseScheduler(1, sigma_max=80, rho=7, sigma_data=0.5) |
455 | | - |
456 | | - if args.experiment_name and args.experiment_name != "": |
| 455 | + |
| 456 | + if args.experiment_name is not None: |
457 | 457 | experiment_name = args.experiment_name |
458 | | - if not args.resume_last_run: |
459 | | - experiment_name = f"{experiment_name}/" + "arguments_hash-{arguments_hash}/date-{date}" |
460 | | - else: |
461 | | - # TODO: Add logic to load the last run from wandb |
462 | | - pass |
463 | 458 | else: |
464 | | - experiment_name = "{name}_{date}".format( |
465 | | - name="Diffusion_SDE_VE_TEXT", date=datetime.now().strftime("%Y-%m-%d_%H:%M:%S") |
466 | | - ) |
| 459 | + experiment_name = "manual-dataset-{dataset}/image_size-{image_size}/batch-{batch_size}/schd-{noise_schedule}/dtype-{dtype}/arch-{architecture}/lr-{learning_rate}/resblks-{num_res_blocks}/emb-{emb_features}/pure-attn-{only_pure_attention}" |
467 | 460 |
|
468 | | - if autoencoder is not None: |
469 | | - experiment_name = f"LDM-{experiment_name}" |
470 | | - |
471 | | - if args.use_hilbert: |
472 | | - experiment_name = f"Hilbert-{experiment_name}" |
473 | | - |
474 | | - conf_args = CONFIG['arguments'] |
475 | | - conf_args['date'] = datetime.now().strftime("%Y-%m-%d_%H:%M:%S") |
476 | | - conf_args['arguments_hash'] = arguments_hash |
477 | | - experiment_name = experiment_name.format(**conf_args) |
| 461 | + # Check if format strings are required using regex |
| 462 | + pattern = r"\{.+?\}" |
| 463 | + if re.search(pattern, experiment_name): |
| 464 | + experiment_name = f"{experiment_name}/" + "arguments_hash-{arguments_hash}/date-{date}" |
| 465 | + if autoencoder is not None: |
| 466 | + experiment_name = f"LDM-{experiment_name}" |
| 467 | + |
| 468 | + if args.use_hilbert: |
| 469 | + experiment_name = f"Hilbert-{experiment_name}" |
| 470 | + |
| 471 | + conf_args = CONFIG['arguments'] |
| 472 | + conf_args['date'] = datetime.now().strftime("%Y-%m-%d_%H:%M:%S") |
| 473 | + conf_args['arguments_hash'] = arguments_hash |
| 474 | + # Format the string with the arguments |
| 475 | + experiment_name = experiment_name.format(**vars(args)) |
| 476 | + else: |
| 477 | + # If no format strings, just use the provided name |
| 478 | + experiment_name = args.experiment_name |
478 | 479 |
|
479 | 480 | print("Experiment_Name:", experiment_name) |
480 | 481 |
|
@@ -511,6 +512,9 @@ def main(args): |
511 | 512 | "name": experiment_name, |
512 | 513 | } |
513 | 514 |
|
| 515 | + if args.resume_last_run is not None: |
| 516 | + wandb_config['id'] = args.resume_last_run |
| 517 | + |
514 | 518 | start_time = time.time() |
515 | 519 |
|
516 | 520 | trainer = GeneralDiffusionTrainer( |
|
0 commit comments