44import torch .nn as nn
55import torch .optim as optim
66import wandb
7- import time
7+ import uuid
8+ import json
89from accelerate import Accelerator
910from torch .optim .lr_scheduler import LambdaLR
1011from tqdm import tqdm
2122)
2223from aging_gan .data import prepare_dataset
2324from aging_gan .model import initialize_models
24- from aging_gan .utils import archive_and_terminate
25+ from aging_gan .utils import terminate_ec2 , archive_ec2
2526
2627logger = logging .getLogger (__name__ )
2728
@@ -106,7 +107,7 @@ def parse_args() -> argparse.Namespace:
106107 p .add_argument (
107108 "--num_workers" ,
108109 type = int ,
109- default = 2 ,
110+ default = 3 ,
110111 help = "Number of workers for dataloaders." ,
111112 )
112113 p .add_argument (
@@ -120,6 +121,12 @@ def parse_args() -> argparse.Namespace:
120121 action = "store_true" ,
121122 help = "Upload outputs/ into s3 bucket and terminate current ec2." ,
122123 )
124+ p .add_argument (
125+ "--s3_bucket_name" ,
126+ type = str ,
127+ default = "aging-gan" ,
128+ help = "Name of your s3 bucket to sync outputs. Effective only when --archive_and_terminate_ec2." ,
129+ )
123130
124131 p .add_argument ("--wandb_project" , type = str , default = "aging-gan" )
125132
@@ -160,7 +167,9 @@ def initialize_optimizers(cfg, G, F, DX, DY):
160167
161168
162169def initialize_loss_functions (
163- lambda_adv_value : float = 2.0 , lambda_cyc_value : float = 10.0 , lambda_id_value : float = 7.0
170+ lambda_adv_value : float = 2.0 ,
171+ lambda_cyc_value : float = 10.0 ,
172+ lambda_id_value : float = 7.0 ,
164173):
165174 mse = nn .MSELoss ()
166175 l1 = nn .L1Loss ()
@@ -500,8 +509,10 @@ def main() -> None:
500509
501510 # ---------- Run Initialization ----------
502511 # wandb
512+ run_name = uuid .uuid4 ().hex [:8 ]
503513 wandb .init (
504514 project = cfg .wandb_project ,
515+ name = run_name ,
505516 config = {
506517 k : v for k , v in vars (cfg ).items () if not k .startswith ("_" )
507518 }, # drop python's or "private" framework-internal attributes
@@ -644,6 +655,9 @@ def main() -> None:
644655 sched_DY , # schedulers
645656 "best" ,
646657 )
658+ # upload outputs to s3 bucket
659+ if cfg .archive_and_terminate_ec2 :
660+ archive_ec2 (bucket = cfg .s3_bucket_name , prefix = f"outputs/run-{ run_name } " )
647661 # save the latest checkpoint
648662 if epoch % 5 == 0 :
649663 save_checkpoint (
@@ -662,6 +676,9 @@ def main() -> None:
662676 sched_DY , # schedulers
663677 "current" ,
664678 )
679+ # upload outputs to s3 bucket
680+ if cfg .archive_and_terminate_ec2 :
681+ archive_ec2 (bucket = cfg .s3_bucket_name , prefix = f"outputs/run-{ run_name } " )
665682
666683 # ---------- Test ----------
667684 if cfg .do_test :
@@ -702,17 +719,25 @@ def main() -> None:
702719 )
703720 logger .info (f"Test metrics (best.pth):\n { test_metrics } " )
704721 wandb .log (test_metrics )
722+ # write metrics out
723+ out_dir = Path (__file__ ).resolve ().parents [2 ] / "outputs" / "metrics"
724+ out_dir .mkdir (parents = True , exist_ok = True )
725+ metrics_path = out_dir / "test_metrics.json"
726+ with metrics_path .open ("w" ) as f :
727+ json .dump (test_metrics , f , indent = 4 )
728+ print (f"Saved test metrics to { metrics_path } " )
729+ # upload outputs to s3 bucket
730+ if cfg .archive_and_terminate_ec2 :
731+ archive_ec2 (bucket = cfg .s3_bucket_name , prefix = f"outputs/run-{ run_name } " )
705732 else :
706733 logger .info ("Skipping test evaluation..." )
707734
708735 # Finished
709736 logger .info ("Finished run." )
710737
711- # upload outputs to s3 bucket and terminate ec2 instance
738+ # terminate ec2 instance
712739 if cfg .archive_and_terminate_ec2 :
713- archive_and_terminate (
714- bucket = "aging-gan" , prefix = f"outputs/run-{ int (time .time ())} "
715- )
740+ terminate_ec2 ()
716741
717742
718743if __name__ == "__main__" :
0 commit comments