Skip to content

Commit 8210454

Browse files
Merge pull request to add-best-practice-docstrings-and-annotations
Add concise docstrings and type hints
2 parents 9691829 + 1906881 commit 8210454

File tree

5 files changed

+68
-35
lines changed

5 files changed

+68
-35
lines changed

src/aging_gan/data.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,22 @@
1+
"""Dataset and dataloader utilities for the UTKFace dataset."""
2+
13
import os
24
import logging
3-
import torch
4-
from torch.utils.data import DataLoader, Subset, Dataset
5-
import torchvision.transforms as T
65
from dataclasses import dataclass
76
from pathlib import Path
7+
from typing import Tuple
8+
9+
import torch
810
from PIL import Image
11+
from torch import Tensor
12+
from torch.utils.data import DataLoader, Dataset, Subset
13+
import torchvision.transforms as T
914

1015
logger = logging.getLogger(__name__)
1116

1217

1318
class UTKFace(Dataset):
14-
"""
15-
Assumes the unzipped aligned UTKFace images live in <root>/data/utkface_aligned_cropped/UTKFace
16-
File pattern: {age}_{gender}_{race}_{yyyymmddHHMMSS}.jpg
17-
"""
19+
"""Lightweight UTKFace dataset reader."""
1820

1921
def __init__(self, root: str, transform: T.Compose | None = None):
2022
self.root = (
@@ -29,11 +31,13 @@ def __init__(self, root: str, transform: T.Compose | None = None):
2931
self.transform = transform
3032

3133
def __len__(self) -> int:
34+
"""Return the number of images in the dataset."""
3235
return len(self.files)
3336

34-
def __getitem__(self, idx):
37+
def __getitem__(self, idx: int) -> Tuple[Tensor, int]:
38+
"""Return the transformed image and associated age label."""
3539
path = self.files[idx]
36-
age = int(path.name.split("_")[0]) # first token of file name is age
40+
age = int(path.name.split("_")[0])
3741
img = Image.open(path).convert("RGB")
3842
if self.transform:
3943
img = self.transform(img)
@@ -49,7 +53,8 @@ def make_unpaired_loader(
4953
seed: int = 42,
5054
young_max: int = 28, # 18-28
5155
old_min: int = 40, # 40+
52-
):
56+
) -> DataLoader:
57+
"""Return a dataloader yielding unpaired young/old image tuples."""
5358
full_ds = UTKFace(root, transform)
5459

5560
# Split into young, old indices
@@ -125,7 +130,8 @@ def prepare_dataset(
125130
num_workers: int = 2,
126131
img_size: int = 256,
127132
seed: int = 42,
128-
):
133+
) -> tuple[DataLoader, DataLoader, DataLoader]:
134+
"""Create train/validation/test dataloaders for UTKFace."""
129135
data_dir = Path(__file__).resolve().parents[2] / "data"
130136
os.makedirs(data_dir, exist_ok=True)
131137

src/aging_gan/inference.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
"""Command-line interface for running a trained generator on a single image."""
2+
13
import argparse
24
from pathlib import Path
35

@@ -9,6 +11,7 @@
911

1012

1113
def parse_args() -> argparse.Namespace:
14+
"""Parse CLI arguments for running inference."""
1215
p = argparse.ArgumentParser(
1316
description="Run one-off inference with a trained Aging-GAN generator"
1417
)
@@ -51,6 +54,7 @@ def parse_args() -> argparse.Namespace:
5154

5255
@torch.inference_mode()
5356
def main() -> None:
57+
"""Load a checkpoint and generate an aged face from ``--input``."""
5458
cfg = parse_args()
5559
device = get_device()
5660

src/aging_gan/model.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
"""Model definitions for the CycleGAN-style architecture."""
2+
3+
from torch import Tensor
14
import torch.nn as nn
25
import torch.nn.functional as F
36

@@ -6,7 +9,9 @@
69

710

811
class ResidualBlock(nn.Module):
9-
def __init__(self, in_features):
12+
"""Simple residual block with two conv layers."""
13+
14+
def __init__(self, in_features: int) -> None:
1015
super().__init__()
1116

1217
conv_block = [
@@ -21,12 +26,15 @@ def __init__(self, in_features):
2126

2227
self.conv_block = nn.Sequential(*conv_block)
2328

24-
def forward(self, x):
25-
return x + self.conv_block(x) # skip connection
29+
def forward(self, x: Tensor) -> Tensor:
30+
"""Apply the residual block."""
31+
return x + self.conv_block(x)
2632

2733

2834
class Generator(nn.Module):
29-
def __init__(self, ngf, n_residual_blocks=9):
35+
"""Sequential residual-block generator used for domain translation."""
36+
37+
def __init__(self, ngf: int, n_residual_blocks: int = 9) -> None:
3038
super().__init__()
3139

3240
# Initial convolution block
@@ -85,12 +93,15 @@ def __init__(self, ngf, n_residual_blocks=9):
8593

8694
self.model = nn.Sequential(*model)
8795

88-
def forward(self, x):
96+
def forward(self, x: Tensor) -> Tensor:
97+
"""Generate an image from ``x``."""
8998
return self.model(x)
9099

91100

92101
class Discriminator(nn.Module):
93-
def __init__(self, ndf):
102+
"""PatchGAN discriminator."""
103+
104+
def __init__(self, ndf: int) -> None:
94105
super().__init__()
95106

96107
model = [
@@ -125,13 +136,10 @@ def __init__(self, ndf):
125136

126137
self.model = nn.Sequential(*model)
127138

128-
def forward(self, x):
129-
# x: (B, 3, H, W)
130-
x = self.model(x) # (B, 1, H//8-2, W//8-2)
131-
# Average pooling and flatten
132-
return F.avg_pool2d(x, x.size()[2:]).view(
133-
x.size()[0], -1
134-
) # global average -> (B, 1, 1, 1) -> flatten to (B, 1)
139+
def forward(self, x: Tensor) -> Tensor:
140+
"""Return discriminator logits for input ``x``."""
141+
x = self.model(x)
142+
return F.avg_pool2d(x, x.size()[2:]).view(x.size()[0], -1)
135143

136144

137145
# # Discriminator: PatchGAN 70x70
@@ -187,7 +195,8 @@ def initialize_models(
187195
ngf: int = 32,
188196
ndf: int = 32,
189197
n_blocks: int = 9,
190-
):
198+
) -> tuple[Generator, Generator, Discriminator, Discriminator]:
199+
"""Instantiate generators and discriminators with default sizes."""
191200
# G = smp.Unet(
192201
# encoder_name="resnet34",
193202
# encoder_weights="imagenet", # preload low-level filters

src/aging_gan/train.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,10 @@ def parse_args() -> argparse.Namespace:
134134
return args
135135

136136

137-
def initialize_optimizers(cfg, G, F, DX, DY):
137+
def initialize_optimizers(
138+
cfg, G, F, DX, DY
139+
) -> tuple[optim.Optimizer, optim.Optimizer, optim.Optimizer, optim.Optimizer]:
140+
"""Create Adam optimizers for all models."""
138141
# track all generator params (even frozen encoder params during initial training).
139142
# This would allow us to transition easily to the full fine-tuning later on by simply toggling requires_grad=True
140143
# since the optimizers already track all the parameters from the start.
@@ -170,7 +173,8 @@ def initialize_loss_functions(
170173
lambda_adv_value: float = 2.0,
171174
lambda_cyc_value: float = 10.0,
172175
lambda_id_value: float = 7.0,
173-
):
176+
) -> tuple[nn.Module, nn.Module, float, float, float]:
177+
"""Return basic CycleGAN loss functions and weights."""
174178
mse = nn.MSELoss()
175179
l1 = nn.L1Loss()
176180
lambda_adv = lambda_adv_value
@@ -180,7 +184,10 @@ def initialize_loss_functions(
180184
return mse, l1, lambda_adv, lambda_cyc, lambda_id
181185

182186

183-
def make_schedulers(cfg, opt_G, opt_F, opt_DX, opt_DY):
187+
def make_schedulers(
188+
cfg, opt_G, opt_F, opt_DX, opt_DY
189+
) -> tuple[LambdaLR, LambdaLR, LambdaLR, LambdaLR]:
190+
"""Return LR schedulers that decay linearly after half the run."""
184191
# keep lr constant for the first half, then linearly decay to 0
185192
n_epochs = cfg.num_train_epochs
186193
start_decay = n_epochs // 2
@@ -215,7 +222,8 @@ def perform_train_step(
215222
opt_DX,
216223
opt_DY, # discriminator optimizers
217224
accelerator,
218-
):
225+
) -> dict[str, float]:
226+
"""Run a single optimization step for generators and discriminators."""
219227
x, y = real_data
220228
# ------ Update Generators ------
221229
opt_G.zero_grad(set_to_none=True)
@@ -304,7 +312,8 @@ def evaluate_epoch(
304312
lambda_id, # loss functions and loss params
305313
fid_metric,
306314
accelerator,
307-
):
315+
) -> dict[str, float]:
316+
"""Evaluate models on ``loader`` and return averaged metrics."""
308317
metrics = {
309318
f"{split}/loss_DX": 0.0,
310319
f"{split}/loss_DY": 0.0,
@@ -416,7 +425,7 @@ def perform_epoch(
416425
epoch,
417426
accelerator,
418427
fid_metric,
419-
):
428+
) -> dict[str, float]:
420429
"""Perform a single epoch."""
421430
# TRAINING
422431
logger.info("Training...")

src/aging_gan/utils.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
"""Utility helpers for training and infrastructure management."""
2+
13
import os
24
import requests
35
import logging
@@ -16,7 +18,8 @@
1618
logger = logging.getLogger(__name__)
1719

1820

19-
def get_device():
21+
def get_device() -> torch.device:
22+
"""Return CUDA device if available else CPU."""
2023
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
2124

2225

@@ -30,7 +33,8 @@ def set_seed(seed: int) -> None:
3033
torch.backends.cudnn.benchmark = False # trade speed for reproducibility
3134

3235

33-
def load_environ_vars(wandb_project: str = "aging-gan"):
36+
def load_environ_vars(wandb_project: str = "aging-gan") -> None:
37+
"""Set basic environment variables needed for a run."""
3438
os.environ["WANDB_PROJECT"] = wandb_project
3539
logger.info(f"W&B project set to '{wandb_project}'")
3640

@@ -64,7 +68,7 @@ def save_checkpoint(
6468
sched_DX,
6569
sched_DY, # schedulers
6670
kind: str = "best",
67-
):
71+
) -> None:
6872
"""Overwrite the single best-ever checkpoint."""
6973
ckpt_dir = Path(__file__).resolve().parents[2] / "outputs/checkpoints"
7074
os.makedirs(ckpt_dir, exist_ok=True)
@@ -103,7 +107,8 @@ def generate_and_save_samples(
103107
epoch,
104108
device: torch.device,
105109
num_samples: int = 8,
106-
):
110+
) -> None:
111+
"""Generate ``num_samples`` images from ``generator`` and save a grid."""
107112
# grab batches until num_samples
108113
collected = []
109114
for imgs, _ in val_loader:

0 commit comments

Comments
 (0)