Skip to content

Commit 3049cbc

Browse files
committed
separate archiving from terminating ec2 for more frequent output dir sync with s3; save test metrics too
1 parent c47807b commit 3049cbc

File tree

2 files changed

+59
-27
lines changed

2 files changed

+59
-27
lines changed

src/aging_gan/train.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
import torch.nn as nn
55
import torch.optim as optim
66
import wandb
7-
import time
7+
import uuid
8+
import json
89
from accelerate import Accelerator
910
from torch.optim.lr_scheduler import LambdaLR
1011
from tqdm import tqdm
@@ -21,7 +22,7 @@
2122
)
2223
from aging_gan.data import prepare_dataset
2324
from aging_gan.model import initialize_models
24-
from aging_gan.utils import archive_and_terminate
25+
from aging_gan.utils import terminate_ec2, archive_ec2
2526

2627
logger = logging.getLogger(__name__)
2728

@@ -106,7 +107,7 @@ def parse_args() -> argparse.Namespace:
106107
p.add_argument(
107108
"--num_workers",
108109
type=int,
109-
default=2,
110+
default=3,
110111
help="Number of workers for dataloaders.",
111112
)
112113
p.add_argument(
@@ -120,6 +121,12 @@ def parse_args() -> argparse.Namespace:
120121
action="store_true",
121122
help="Upload outputs/ into s3 bucket and terminate current ec2.",
122123
)
124+
p.add_argument(
125+
"--s3_bucket_name",
126+
type=str,
127+
default="aging-gan",
128+
help="Name of your s3 bucket to sync outputs. Effective only when --archive_and_terminate_ec2.",
129+
)
123130

124131
p.add_argument("--wandb_project", type=str, default="aging-gan")
125132

@@ -160,7 +167,9 @@ def initialize_optimizers(cfg, G, F, DX, DY):
160167

161168

162169
def initialize_loss_functions(
163-
lambda_adv_value: float = 2.0, lambda_cyc_value: float = 10.0, lambda_id_value: float = 7.0
170+
lambda_adv_value: float = 2.0,
171+
lambda_cyc_value: float = 10.0,
172+
lambda_id_value: float = 7.0,
164173
):
165174
mse = nn.MSELoss()
166175
l1 = nn.L1Loss()
@@ -500,8 +509,10 @@ def main() -> None:
500509

501510
# ---------- Run Initialization ----------
502511
# wandb
512+
run_name = uuid.uuid4().hex[:8]
503513
wandb.init(
504514
project=cfg.wandb_project,
515+
name=run_name,
505516
config={
506517
k: v for k, v in vars(cfg).items() if not k.startswith("_")
507518
}, # drop python's or "private" framework-internal attributes
@@ -644,6 +655,9 @@ def main() -> None:
644655
sched_DY, # schedulers
645656
"best",
646657
)
658+
# upload outputs to s3 bucket
659+
if cfg.archive_and_terminate_ec2:
660+
archive_ec2(bucket=cfg.s3_bucket_name, prefix=f"outputs/run-{run_name}")
647661
# save the latest checkpoint
648662
if epoch % 5 == 0:
649663
save_checkpoint(
@@ -662,6 +676,9 @@ def main() -> None:
662676
sched_DY, # schedulers
663677
"current",
664678
)
679+
# upload outputs to s3 bucket
680+
if cfg.archive_and_terminate_ec2:
681+
archive_ec2(bucket=cfg.s3_bucket_name, prefix=f"outputs/run-{run_name}")
665682

666683
# ---------- Test ----------
667684
if cfg.do_test:
@@ -702,17 +719,25 @@ def main() -> None:
702719
)
703720
logger.info(f"Test metrics (best.pth):\n{test_metrics}")
704721
wandb.log(test_metrics)
722+
# write metrics out
723+
out_dir = Path(__file__).resolve().parents[2] / "outputs" / "metrics"
724+
out_dir.mkdir(parents=True, exist_ok=True)
725+
metrics_path = out_dir / "test_metrics.json"
726+
with metrics_path.open("w") as f:
727+
json.dump(test_metrics, f, indent=4)
728+
print(f"Saved test metrics to {metrics_path}")
729+
# upload outputs to s3 bucket
730+
if cfg.archive_and_terminate_ec2:
731+
archive_ec2(bucket=cfg.s3_bucket_name, prefix=f"outputs/run-{run_name}")
705732
else:
706733
logger.info("Skipping test evaluation...")
707734

708735
# Finished
709736
logger.info("Finished run.")
710737

711-
# upload outputs to s3 bucket and terminate ec2 instance
738+
# terminate ec2 instance
712739
if cfg.archive_and_terminate_ec2:
713-
archive_and_terminate(
714-
bucket="aging-gan", prefix=f"outputs/run-{int(time.time())}"
715-
)
740+
terminate_ec2()
716741

717742

718743
if __name__ == "__main__":

src/aging_gan/utils.py

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import numpy as np
66
import torch
77
import matplotlib.pyplot as plt
8+
import subprocess
89
import boto3
910
import time
1011
from dotenv import load_dotenv
@@ -103,7 +104,7 @@ def generate_and_save_samples(
103104
device: torch.device,
104105
num_samples: int = 8,
105106
):
106-
# grab batches until num_samples
107+
# grab batches until num_samples
107108
collected = []
108109
for imgs, _ in val_loader:
109110
collected.append(imgs)
@@ -112,9 +113,9 @@ def generate_and_save_samples(
112113

113114
if not collected:
114115
raise ValueError("Validation loader is empty.")
115-
116+
116117
inputs = torch.cat(collected, dim=0)[:num_samples].to(device)
117-
118+
118119
with torch.no_grad():
119120
outputs = generator(inputs)
120121

@@ -149,29 +150,35 @@ def generate_and_save_samples(
149150
plt.close(fig)
150151

151152

152-
def archive_and_terminate(
153+
def archive_ec2(
153154
bucket: str,
154-
prefix: str = "outputs/boto3/",
155+
prefix: str = "outputs",
155156
) -> None:
157+
"""Syncs everything under ./outputs to `s3://{bucket}/{prefix}/`."""
158+
# Upload
159+
out_root = Path(__file__).resolve().parents[2] / "outputs"
160+
print(f"Uploading {out_root} -> s3://{bucket}/{prefix}/ ...")
161+
cmd = [
162+
"aws",
163+
"s3",
164+
"sync",
165+
str(out_root),
166+
f"s3://{bucket}/{prefix}",
167+
"--only-show-errors", # quieter logging
168+
]
169+
subprocess.run(cmd, check=True)
170+
print("S3 sync complete")
171+
172+
173+
def terminate_ec2() -> None:
156174
"""
157-
1. Recursively uploads everything under ./outputs to `s3://{bucket}/{prefix}/`.
158-
2. Calls the EC2 API to terminate *this* instance.
175+
Calls the EC2 API to terminate this instance.
159176
160177
The instance must run with an IAM role that can:
161178
s3:PutObject on arn:aws:s3:::{bucket}/*
162179
ec2:TerminateInstances on itself (resource‑level ARN)
163180
"""
164-
# Upload
165-
s3 = boto3.client("s3")
166-
out_root = Path(__file__).resolve().parents[2] / "outputs/"
167-
print(f"Uploading {out_root} -> s3://{bucket}/{prefix}/ ...")
168-
for fp in out_root.rglob("*"):
169-
if fp.is_file():
170-
key = f"{prefix}/{fp.relative_to(out_root)}"
171-
s3.upload_file(str(fp), bucket, key)
172-
print("S3 sync complete")
173-
174-
# ---------- 2. Gather instance metadata (IMDSv2) ---------------------------------------
181+
# Gather instance metadata (IMDSv2)
175182
token = requests.put(
176183
"http://169.254.169.254/latest/api/token",
177184
headers={"X-aws-ec2-metadata-token-ttl-seconds": "300"},
@@ -190,7 +197,7 @@ def archive_and_terminate(
190197
).text
191198
print(f"Terminating {instance_id} in {region}")
192199

193-
# ---------- 3. Terminate self ----------------------------------------------------------
200+
# Terminate self
194201
ec2 = boto3.client("ec2", region_name=region)
195202
ec2.terminate_instances(InstanceIds=[instance_id])
196203
print("Termination request sent - instance will shut down shortly")

0 commit comments

Comments
 (0)