
Commit 65ff72f

changed age limit and limit function for data preprocessing; removed unneeded functions in utils and train
1 parent: 1eb9d45

File tree

3 files changed: +69 −89 lines


src/aging_gan/data.py

Lines changed: 18 additions & 35 deletions
@@ -12,13 +12,13 @@
 
 class UTKFace(Dataset):
     """
-    Assumes the unzipped UTKFace images live in <root>/data/utkface_aligned_cropped/UTKFace
+    Assumes the unzipped UTKFace images live in <root>/data/UTKFace
     File pattern: {age}_{gender}_{race}_{yyyymmddHHMMSS}.jpg
     """
 
     def __init__(self, root: str, transform: T.Compose | None = None):
         self.root = (
-            Path(root) / "utkface_aligned_cropped" / "UTKFace"
+            Path(root) / "UTKFace"  # "utkface_aligned_cropped" /
         )  # or "UTKFace" for the unaligned and varied original version.
         self.files = sorted(f for f in self.root.glob("*.jpg"))
         if not self.files:
@@ -46,10 +46,9 @@ def make_unpaired_loader(
     transform: T.Compose,
     batch_size: int = 4,
     num_workers: int = 1,
-    limit: int | None = None,  # per-domain cap
     seed: int = 42,
-    young_max: int = 25,  # 0-25
-    old_min: int = 55,  # 55+
+    young_max: int = 28,  # 18-28
+    old_min: int = 40,  # 40+
 ):
     full_ds = UTKFace(root, transform)
 
@@ -60,7 +59,7 @@ def make_unpaired_loader(
 
     for i, f in enumerate(full_ds.files):
         age = int(f.name.split("_")[0])
-        if age <= young_max:
+        if age <= young_max and age >= 18:
             young_idx.append(i)
         elif age >= old_min:
             old_idx.append(i)
@@ -84,10 +83,10 @@ def split_indices(idxs: list[int]):
     part_y = split_indices(young_idx)[split].tolist()
     part_o = split_indices(old_idx)[split].tolist()
 
-    # Limit per domain
-    if limit is not None:
-        part_y = part_y[:limit]
-        part_o = part_o[:limit]
+    # same dataset length
+    limit = min(len(part_y), len(part_o))
+    part_y = part_y[:limit]
+    part_o = part_o[:limit]
 
     # Wrap subsets in unpaird Dataset
     @dataclass
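For readers skimming the diff: the per-domain `limit` argument is gone, and after the 18-28 / 40+ bucketing both domains are always trimmed to the same length. A minimal standalone sketch of that selection logic (hypothetical file names; the real code walks `full_ds.files` and applies the cap after the train/val/test split):

# Sketch of the new bucketing + equal-length cap (hypothetical file names).
from pathlib import Path

files = [Path("22_0_0_20170101.jpg"), Path("45_1_2_20170101.jpg"), Path("17_0_0_20170101.jpg")]
young_max, old_min = 28, 40

young_idx, old_idx = [], []
for i, f in enumerate(files):
    age = int(f.name.split("_")[0])
    if 18 <= age <= young_max:  # young domain is now 18-28
        young_idx.append(i)
    elif age >= old_min:        # old domain is now 40+
        old_idx.append(i)

# Instead of an explicit `limit`, both domains are trimmed to the shorter length.
limit = min(len(young_idx), len(old_idx))
young_idx, old_idx = young_idx[:limit], old_idx[:limit]
print(young_idx, old_idx)  # -> [0] [1]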
@@ -107,9 +106,7 @@ def __getitem__(self, idx: int):
     old_ds = Subset(full_ds, part_o)
     paired = Unpaired(young_ds, old_ds)
 
-    logger.info(
-        f"- UTK {split}: young={len(young_ds)} old={len(old_ds)}" f"(limit={limit})"
-    )
+    logger.info(f"- UTK {split}: young={len(young_ds)} old={len(old_ds)}")
     return DataLoader(
         paired,
         batch_size=batch_size,
@@ -126,11 +123,8 @@ def prepare_dataset(
     train_batch_size: int = 4,
     eval_batch_size: int = 8,
     num_workers: int = 2,
-    center_crop_size: int = 256,
+    img_size: int = 256,
     resize_size: int = 286,
-    train_size: int | None = None,  # None = use all
-    val_size: int | None = None,
-    test_size: int | None = None,
     seed: int = 42,
 ):
     data_dir = Path(__file__).resolve().parents[2] / "data"
@@ -139,19 +133,11 @@ def prepare_dataset(
     # randomness
     train_transform = T.Compose(
         [
-            T.Resize(resize_size, antialias=True),
-            T.CenterCrop(center_crop_size),
-            T.RandomApply(
-                [
-                    T.RandomAffine(
-                        degrees=5, translate=(0.02, 0.02), scale=(0.97, 1.03), shear=2,
-                        interpolation=T.InterpolationMode.BILINEAR, fill=0,
-                    )
-                ],
-                p=0.3,
-            ),
-            T.RandomHorizontalFlip(0.5),
-            T.ColorJitter(0.05, 0.05, 0.05, 0.02),
+            T.ToPILImage(),
+            T.RandomHorizontalFlip(),
+            T.Resize((img_size + 50, img_size + 50), antialias=True),
+            T.RandomCrop(img_size),
+            T.RandomRotation(degrees=(0, 80)),
             T.ToTensor(),
             T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
         ]
@@ -160,8 +146,8 @@ def prepare_dataset(
     # deterministic
     eval_transform = T.Compose(
         [
-            T.CenterCrop(center_crop_size),
-            T.Resize(resize_size),
+            T.Resize(resize_size, antialias=True),
+            T.CenterCrop(img_size),
             T.ToTensor(),
             T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
         ]
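A quick shape/range check on the reworked train pipeline (a sketch; it assumes the dataset yields HWC uint8 arrays, which is why `T.ToPILImage()` now leads the chain, and that values land roughly in [-1, 1] after `Normalize`):

# Sanity-check the new train_transform on a dummy image (assumption: the
# dataset returns HWC uint8 arrays, hence the leading T.ToPILImage()).
import numpy as np
import torchvision.transforms as T

img_size = 256
train_transform = T.Compose(
    [
        T.ToPILImage(),
        T.RandomHorizontalFlip(),
        T.Resize((img_size + 50, img_size + 50), antialias=True),
        T.RandomCrop(img_size),
        T.RandomRotation(degrees=(0, 80)),
        T.ToTensor(),
        T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
)

dummy = np.random.randint(0, 256, (200, 200, 3), dtype=np.uint8)
out = train_transform(dummy)
print(out.shape)  # torch.Size([3, 256, 256]), values roughly in [-1, 1]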
@@ -175,7 +161,6 @@ def prepare_dataset(
         train_transform,
         train_batch_size,
         num_workers,
-        train_size,
         seed,
     )
     val_loader = make_unpaired_loader(
@@ -184,7 +169,6 @@ def prepare_dataset(
         eval_transform,
         eval_batch_size,
         num_workers,
-        val_size,
         seed,
     )
     test_loader = make_unpaired_loader(
@@ -193,7 +177,6 @@ def prepare_dataset(
         eval_transform,
         eval_batch_size,
         num_workers,
-        test_size,
         seed,
     )
     logger.info("Done.")
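With the per-split size caps removed, a call site only chooses batch sizes, workers, and image geometry. A usage sketch (assuming `prepare_dataset` still returns the train/val/test loaders; the keyword names follow the new signature above):

# Hypothetical call site reflecting the new signature (no more train_size /
# val_size / test_size; img_size replaces center_crop_size).
from aging_gan.data import prepare_dataset

train_loader, val_loader, test_loader = prepare_dataset(
    train_batch_size=4,
    eval_batch_size=8,
    num_workers=2,
    img_size=256,
    resize_size=286,
    seed=42,
)
print(len(train_loader), len(val_loader), len(test_loader))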

src/aging_gan/train.py

Lines changed: 39 additions & 42 deletions
@@ -15,13 +15,12 @@
 from aging_gan.utils import (
     set_seed,
     load_environ_vars,
-    print_trainable_parameters,
     save_checkpoint,
     generate_and_save_samples,
     get_device,
 )
 from aging_gan.data import prepare_dataset
-from aging_gan.model import initialize_models, freeze_encoders, unfreeze_encoders
+from aging_gan.model import initialize_models
 from aging_gan.utils import archive_and_terminate
 
 logger = logging.getLogger(__name__)
@@ -59,6 +58,12 @@ def parse_args() -> argparse.Namespace:
         default=32,
         help="Batch size per device during evaluation.",
     )
+    p.add_argument(
+        "--lambda_adv_value",
+        type=int,
+        default=2,
+        help="Weight for adversarial loss",
+    )
     p.add_argument(
         "--lambda_cyc_value",
         type=int,
@@ -98,24 +103,6 @@ def parse_args() -> argparse.Namespace:
         default=10,
         help="The number of example generated images to save per epoch.",
     )
-    p.add_argument(
-        "--train_size",
-        type=int,
-        default=3000,
-        help="The size of train dataset to train on.",
-    )
-    p.add_argument(
-        "--val_size",
-        type=int,
-        default=800,
-        help="The size of validation dataset to evaluate.",
-    )
-    p.add_argument(
-        "--test_size",
-        type=int,
-        default=800,
-        help="The size of test dataset to evaluate.",
-    )
     p.add_argument(
         "--num_workers",
         type=int,
@@ -176,13 +163,16 @@ def initialize_optimizers(cfg, G, F, DX, DY):
     return opt_G, opt_F, opt_DX, opt_DY
 
 
-def initialize_loss_functions(lambda_cyc_value: int = 10, lambda_id_value: int = 5):
+def initialize_loss_functions(
+    lambda_adv_value: int = 2, lambda_cyc_value: int = 10, lambda_id_value: int = 5
+):
     mse = nn.MSELoss()
     l1 = nn.L1Loss()
+    lambda_adv = lambda_adv_value
     lambda_cyc = lambda_cyc_value
     lambda_id = lambda_id_value
 
-    return mse, l1, lambda_cyc, lambda_id
+    return mse, l1, lambda_adv, lambda_cyc, lambda_id
 
 
 def make_schedulers(cfg, opt_G, opt_F, opt_DX, opt_DY):
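Call sites that unpacked four values must now unpack five. A minimal usage sketch (assuming importing aging_gan.train has no side effects beyond logger setup):

# The adversarial weight now travels with the other loss scalers.
from aging_gan.train import initialize_loss_functions

mse, l1, lambda_adv, lambda_cyc, lambda_id = initialize_loss_functions(
    lambda_adv_value=2, lambda_cyc_value=10, lambda_id_value=5
)
print(lambda_adv, lambda_cyc, lambda_id)  # 2 10 5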
@@ -212,6 +202,7 @@ def perform_train_step(
     real_data,
     mse,
     l1,
+    lambda_adv,
     lambda_cyc,
     lambda_id,  # loss functions and loss params
     opt_G,
@@ -260,10 +251,10 @@ def perform_train_step(
     opt_F.zero_grad(set_to_none=True)
     # Loss 1: adversarial terms
    fake_test_logits = DX(fake_x)  # fake x logits
-    loss_f_adv = mse(fake_test_logits, torch.ones_like(fake_test_logits))
+    loss_f_adv = lambda_adv * mse(fake_test_logits, torch.ones_like(fake_test_logits))
 
     fake_test_logits = DY(fake_y)  # fake y logits
-    loss_g_adv = mse(fake_test_logits, torch.ones_like(fake_test_logits))
+    loss_g_adv = lambda_adv * mse(fake_test_logits, torch.ones_like(fake_test_logits))
     # Loss 2: cycle terms
     loss_cyc = lambda_cyc * (l1(rec_x, x) + l1(rec_y, y))
     # Loss 3: identity terms
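The full generator objective is not visible in this hunk; below is a sketch of how the newly weighted adversarial terms would combine with the cycle and identity terms (dummy tensors, LSGAN-style all-ones targets as in the diff; the exact summation in train.py may differ):

# Sketch: assembling a total generator loss from the weighted terms.
import torch
import torch.nn as nn

mse, l1 = nn.MSELoss(), nn.L1Loss()
lambda_adv, lambda_cyc, lambda_id = 2, 10, 5

fake_x_logits = torch.randn(4, 1, 30, 30)  # stand-ins for DX(fake_x), DY(fake_y)
fake_y_logits = torch.randn(4, 1, 30, 30)
x = torch.randn(4, 3, 256, 256)            # stand-ins for real, reconstructed, identity images
y = torch.randn_like(x)
rec_x, rec_y = torch.randn_like(x), torch.randn_like(y)
id_x, id_y = torch.randn_like(x), torch.randn_like(y)

loss_f_adv = lambda_adv * mse(fake_x_logits, torch.ones_like(fake_x_logits))
loss_g_adv = lambda_adv * mse(fake_y_logits, torch.ones_like(fake_y_logits))
loss_cyc = lambda_cyc * (l1(rec_x, x) + l1(rec_y, y))
loss_id = lambda_id * (l1(id_x, x) + l1(id_y, y))

loss_gen = loss_f_adv + loss_g_adv + loss_cyc + loss_id
print(loss_gen.item())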
@@ -299,6 +290,7 @@ def evaluate_epoch(
     split: str,  # either "val" or "test"
     mse,
     l1,
+    lambda_adv,
     lambda_cyc,
     lambda_id,  # loss functions and loss params
     fid_metric,
@@ -349,10 +341,14 @@ def evaluate_epoch(
         # ------ Evaluate Generators ------
         # Loss 1: adversarial terms
         fake_test_logits = DX(fake_x)  # fake x logits
-        loss_f_adv = mse(fake_test_logits, torch.ones_like(fake_test_logits))
+        loss_f_adv = lambda_adv * mse(
+            fake_test_logits, torch.ones_like(fake_test_logits)
+        )
 
         fake_test_logits = DY(fake_y)  # fake y logits
-        loss_g_adv = mse(fake_test_logits, torch.ones_like(fake_test_logits))
+        loss_g_adv = lambda_adv * mse(
+            fake_test_logits, torch.ones_like(fake_test_logits)
+        )
         # Loss 2: cycle terms
         loss_cyc = lambda_cyc * (l1(rec_x, x) + l1(rec_y, y))
         # Loss 3: identity terms
@@ -396,6 +392,7 @@ def perform_epoch(
     DY,
     mse,
     l1,
+    lambda_adv,
     lambda_cyc,
     lambda_id,
     opt_G,
@@ -427,6 +424,7 @@ def perform_epoch(
             real_data,
             mse,
             l1,
+            lambda_adv,
             lambda_cyc,
             lambda_id,  # loss functions and loss params
             opt_G,
@@ -469,6 +467,7 @@ def perform_epoch(
         "val",
         mse,
         l1,
+        lambda_adv,
         lambda_cyc,
         lambda_id,  # loss functions and loss params
         fid_metric,  # evaluation metric
@@ -527,22 +526,19 @@ def main() -> None:
         cfg.train_batch_size,
         cfg.eval_batch_size,
         cfg.num_workers,
-        train_size=cfg.train_size,
-        val_size=cfg.val_size,
-        test_size=cfg.test_size,
         seed=cfg.seed,
     )
 
     # ---------- Models, Optimizers, Loss Functions, Schedulers Initialization ----------
     # Initialize the generators (G, F) and discriminators (DX, DY)
     G, F, DX, DY = initialize_models()
     # Freeze generator encoderes for training during early epochs
-    logger.info("Parameters of generator G:")
-    logger.info(print_trainable_parameters(G))
-    logger.info("Freezing encoders of generators...")
-    freeze_encoders(G, F)
-    logger.info("Parameters of generator G after freezing:")
-    logger.info(print_trainable_parameters(G))
+    # logger.info("Parameters of generator G:")
+    # logger.info(print_trainable_parameters(G))
+    # logger.info("Freezing encoders of generators...")
+    # freeze_encoders(G, F)
+    # logger.info("Parameters of generator G after freezing:")
+    # logger.info(print_trainable_parameters(G))
     # Initialize optimizers
     (
         opt_G,
@@ -579,8 +575,8 @@ def main() -> None:
         test_loader,
     )
     # Loss functions and scalers
-    mse, l1, lambda_cyc, lambda_id = initialize_loss_functions(
-        cfg.lambda_cyc_value, cfg.lambda_id_value
+    mse, l1, lambda_adv, lambda_cyc, lambda_id = initialize_loss_functions(
+        cfg.lambda_adv_value, cfg.lambda_cyc_value, cfg.lambda_id_value
     )
     # Initialize schedulers (It it important this comes AFTER wrapping optimizers in accelerator)
     sched_G, sched_F, sched_DX, sched_DY = make_schedulers(
@@ -596,11 +592,11 @@ def main() -> None:
     for epoch in range(1, cfg.num_train_epochs + 1):
         logger.info(f"\nEPOCH {epoch}")
         # after 1 full epoch, unfreeze
-        if epoch == 2:
-            logger.info("Unfreezing encoders of generators...")
-            unfreeze_encoders(G, F)
-            logger.info("Parameters of generator G after unfreezing:")
-            logger.info(print_trainable_parameters(G))
+        # if epoch == 2:
+        #     logger.info("Unfreezing encoders of generators...")
+        #     unfreeze_encoders(G, F)
+        #     logger.info("Parameters of generator G after unfreezing:")
+        #     logger.info(print_trainable_parameters(G))
 
         val_metrics = perform_epoch(
             cfg,
@@ -612,6 +608,7 @@ def main() -> None:
             DY,
             mse,
             l1,
+            lambda_adv,
             lambda_cyc,
             lambda_id,
             opt_G,

src/aging_gan/utils.py

Lines changed: 12 additions & 12 deletions
@@ -34,18 +34,18 @@ def load_environ_vars(wandb_project: str = "aging-gan"):
     logger.info(f"W&B project set to '{wandb_project}'")
 
 
-def print_trainable_parameters(model) -> str:
-    """
-    Compute and return a summary of trainable vs. total parameters in a model.
-    """
-    trainable_params = 0
-    all_param = 0
-    for _, param in model.named_parameters():
-        all_param += param.numel()
-        if param.requires_grad:
-            trainable_params += param.numel()
-
-    return f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param:.2f}"
+# def print_trainable_parameters(model) -> str:
+#     """
+#     Compute and return a summary of trainable vs. total parameters in a model.
+#     """
+#     trainable_params = 0
+#     all_param = 0
+#     for _, param in model.named_parameters():
+#         all_param += param.numel()
+#         if param.requires_grad:
+#             trainable_params += param.numel()
+
+#     return f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param:.2f}"
 
 
 def save_checkpoint(
