Skip to content

Commit b9741fe

Browse files
committed
tuned hyperparams; edited save checkpoint to every 5 epochs
1 parent 65ff72f commit b9741fe

File tree

3 files changed

+27
-33
lines changed

3 files changed

+27
-33
lines changed

src/aging_gan/data.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,6 @@ def prepare_dataset(
124124
eval_batch_size: int = 8,
125125
num_workers: int = 2,
126126
img_size: int = 256,
127-
resize_size: int = 286,
128127
seed: int = 42,
129128
):
130129
data_dir = Path(__file__).resolve().parents[2] / "data"
@@ -146,7 +145,7 @@ def prepare_dataset(
146145
# deterministic
147146
eval_transform = T.Compose(
148147
[
149-
T.Resize(resize_size, antialias=True),
148+
T.Resize((img_size + 50, img_size + 50), antialias=True),
150149
T.CenterCrop(img_size),
151150
T.ToTensor(),
152151
T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),

src/aging_gan/train.py

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -34,28 +34,28 @@ def parse_args() -> argparse.Namespace:
3434
p.add_argument(
3535
"--gen_lr",
3636
type=float,
37-
default=3e-4,
37+
default=1e-4,
3838
help="Initial learning rate for generators.",
3939
)
4040
p.add_argument(
4141
"--disc_lr",
4242
type=float,
43-
default=2e-4,
43+
default=1e-4,
4444
help="Initial learning rate for discriminators.",
4545
)
4646
p.add_argument(
47-
"--num_train_epochs", type=int, default=25, help="Number of training epochs."
47+
"--num_train_epochs", type=int, default=100, help="Number of training epochs."
4848
)
4949
p.add_argument(
5050
"--train_batch_size",
5151
type=int,
52-
default=16,
52+
default=4,
5353
help="Batch size per device during training.",
5454
)
5555
p.add_argument(
5656
"--eval_batch_size",
5757
type=int,
58-
default=32,
58+
default=8,
5959
help="Batch size per device during evaluation.",
6060
)
6161
p.add_argument(
@@ -67,7 +67,7 @@ def parse_args() -> argparse.Namespace:
6767
p.add_argument(
6868
"--lambda_cyc_value",
6969
type=int,
70-
default=7,
70+
default=10,
7171
help="Weight for cyclical loss",
7272
)
7373
p.add_argument(
@@ -164,7 +164,7 @@ def initialize_optimizers(cfg, G, F, DX, DY):
164164

165165

166166
def initialize_loss_functions(
167-
lambda_adv_value: int = 2, lambda_cyc_value: int = 10, lambda_id_value: int = 5
167+
lambda_adv_value: int = 2, lambda_cyc_value: int = 10, lambda_id_value: int = 7
168168
):
169169
mse = nn.MSELoss()
170170
l1 = nn.L1Loss()
@@ -643,22 +643,23 @@ def main() -> None:
643643
"best",
644644
)
645645
# save the latest checkpoint
646-
save_checkpoint(
647-
epoch,
648-
G,
649-
F,
650-
DX,
651-
DY,
652-
opt_G,
653-
opt_F, # generator optimizers
654-
opt_DX,
655-
opt_DY, # discriminator optimizers
656-
sched_G,
657-
sched_F,
658-
sched_DX,
659-
sched_DY, # schedulers
660-
"latest",
661-
)
646+
if epoch % 5 == 0:
647+
save_checkpoint(
648+
epoch,
649+
G,
650+
F,
651+
DX,
652+
DY,
653+
opt_G,
654+
opt_F, # generator optimizers
655+
opt_DX,
656+
opt_DY, # discriminator optimizers
657+
sched_G,
658+
sched_F,
659+
sched_DX,
660+
sched_DY, # schedulers
661+
"current",
662+
)
662663

663664
# ---------- Test ----------
664665
if cfg.do_test:

src/aging_gan/utils.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def save_checkpoint(
6464
sched_DY, # schedulers
6565
kind: str = "best",
6666
):
67-
"""Overwrite the single bestever checkpoint."""
67+
"""Overwrite the single best-ever checkpoint."""
6868
ckpt_dir = Path(__file__).resolve().parents[2] / "outputs/checkpoints"
6969
os.makedirs(ckpt_dir, exist_ok=True)
7070

@@ -88,15 +88,9 @@ def save_checkpoint(
8888
filename = os.path.join(ckpt_dir, "best.pth")
8989
torch.save(state, filename)
9090
logger.info(f"Saved best checkpoint: {filename}")
91-
elif kind == "latest":
91+
elif kind == "current":
9292
new_latest = ckpt_dir / f"epoch_{epoch:04d}.pth"
9393
torch.save(state, new_latest)
94-
95-
# remove previous epoch_*.pth checkpoints
96-
for f in ckpt_dir.glob("epoch_*.pth"):
97-
if f != new_latest:
98-
f.unlink(missing_ok=True)
99-
10094
logger.info(f"Saved latest checkpoint: {new_latest}")
10195
else:
10296
raise ValueError(f"kind must be 'best' or 'latest', got {kind}")

0 commit comments

Comments
 (0)