Commit 8b75f82: added annotations for tests
Parent: 7e7c8e2

5 files changed, +33 -13 lines changed

README.md

Lines changed: 21 additions & 13 deletions
````diff
@@ -1,22 +1,20 @@
 different lr for gen and disc
 
-
 # Aging GAN
 
-Aging GAN is a research project exploring face aging with a CycleGAN-style architecture. The code trains two U-Net generators and two PatchGAN discriminators on the CelebA dataset, preprocessing into **Young** and **Old** subsets. The generators learn to translate between these domains, effectively "aging" or "de-aging" a face image.
+Aging GAN is a research project exploring facial age transformation with a CycleGAN-style approach. The model trains two ResNet-style "encoder–residual–decoder" generators and two PatchGAN discriminators on the UTKFace dataset, split into **Young** and **Old** subsets. The generators learn to translate between these domains, effectively "aging" or "de-aging" a face image.
 
-This repository contains training scripts, minimal utilities, and example notebooks.
+This repository contains training scripts, helper utilities, and inference scripts.
 
 
 ## Features
 
-- **Unpaired Training Data** – automatically split the CelebA dataset into Young vs. Old and create an unpaired `DataLoader`.
-- **CycleGAN Architecture** – U-Net generators with ResNet encoders and PatchGAN discriminators.
-- **Training Utilities** – gradient clipping, learning-rate scheduling, mixed-precision via `accelerate`, and optional encoder freezing.
-- **Evaluation** – FID metric computation on the validation set.
+- **Unpaired Training Data** – splits the UTKFace dataset into *Young* (18–28) and *Old* (40+) subsets and builds an unpaired `DataLoader`.
+- **CycleGAN Architecture** – residual U-Net generators and PatchGAN discriminators.
+- **Training Utilities** – gradient clipping, separate generator/discriminator learning rates with linear decay, mixed precision via `accelerate`, and optional S3 checkpoint archiving.
+- **Evaluation** – FID metric computation on the validation and test sets.
 - **Weights & Biases Logging** – track losses and metrics during training.
-- **Scriptable Workflows** – run training from the command line with `scripts/run_train.sh`.
-
-*Placeholders:* inference helpers, web demo, and quantitative results will be added later.
+- **Scriptable Workflows** – shell scripts for training and inference.
+- **Sample Generation** – saves example outputs after each epoch.
 
 
 ## Installation
 
````
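The updated description calls the generators ResNet-style "encoder–residual–decoder" networks. For context, a single residual block in that kind of CycleGAN generator usually looks like the sketch below; this illustrates the general pattern and is not the repository's exact `model.Generator`.

```python
import torch
import torch.nn as nn


class ResidualBlock(nn.Module):
    """One CycleGAN-style residual block: two 3x3 convs plus a skip connection."""

    def __init__(self, channels: int):
        super().__init__()
        self.block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, kernel_size=3),
            nn.InstanceNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, kernel_size=3),
            nn.InstanceNorm2d(channels),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The skip connection preserves spatial dimensions.
        return x + self.block(x)
```

A full generator stacks a downsampling encoder, `n_residual_blocks` of these blocks, and an upsampling decoder, so the output keeps the input image's shape, which is what `test_generator_output_shape` below verifies.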
````diff
@@ -38,7 +36,9 @@ pip install -e .
 ```
 
 ## Data
-The `prepare_dataset` function downloads CelebA automatically and creates train, validation, and test splits. Images are center-cropped and resized to 256×256. Each split is divided into *Young* and *Old* subsets for unpaired training.
+Place the aligned UTKFace images under `data/utkface_aligned_cropped/UTKFace`.
+The `prepare_dataset` function builds deterministic train/val/test splits and applies random flipping, cropping and rotation for training.
+Each split is divided into *Young* and *Old* subsets for unpaired training.
 
 ## Training
 Run training with default hyper-parameters:
````
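The rewritten Data section buckets UTKFace into *Young* (18–28) and *Old* (40+). Standard UTKFace filenames encode the age as the first underscore-separated field (`[age]_[gender]_[race]_[date].jpg`), so a split along those lines could look like the sketch below; the function name and thresholds here are illustrative, and the real logic lives in `prepare_dataset`.

```python
from pathlib import Path


def split_young_old(utkface_dir: str):
    """Bucket UTKFace images by the age encoded in each filename (illustrative only)."""
    young, old = [], []
    for img_path in Path(utkface_dir).glob("*.jpg"):
        try:
            age = int(img_path.name.split("_")[0])  # UTKFace puts the age first
        except ValueError:
            continue  # skip files that do not follow the naming convention
        if 18 <= age <= 28:
            young.append(img_path)
        elif age >= 40:
            old.append(img_path)
    return young, old


young, old = split_young_old("data/utkface_aligned_cropped/UTKFace")
print(len(young), "young images,", len(old), "old images")
```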
````diff
@@ -54,7 +54,15 @@ python -m aging_gan.train --help
 ```
 
 ## Inference
-The `aging_gan.inference` module is currently a stub. Once implemented, you will be able to generate aged faces from the command line using `scripts/run_inference.sh`.
+Generate aged faces using the command-line helper:
+
+```bash
+bash scripts/run_inference.sh --input myface.jpg --direction young2old
+```
+The script loads `outputs/checkpoints/best.pth` by default and saves the result beside the input.
+
+## AWS Utilities
+When running on EC2, pass `--archive_and_terminate_ec2` to automatically sync `outputs/` to S3 and terminate the instance after training.
 
 ## Results
 *Results will be added here once experiments are complete.*
````
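The new Inference section delegates to `scripts/run_inference.sh`. Under the hood this amounts to loading a generator from the checkpoint and running one forward pass. A minimal Python sketch, assuming the checkpoint stores the young-to-old generator's weights under a `"G"` key; the actual key names, generator arguments, and preprocessing are defined by the project.

```python
import torch
from PIL import Image
from torchvision import transforms

from aging_gan import model

# Assumed checkpoint layout; the real keys depend on how training saves checkpoints.
ckpt = torch.load("outputs/checkpoints/best.pth", map_location="cpu")
G = model.Generator()  # pass ngf / n_residual_blocks here if the defaults differ from training
G.load_state_dict(ckpt["G"])
G.eval()

preprocess = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * 3, [0.5] * 3),  # map pixels to [-1, 1], typical for GANs
])

img = preprocess(Image.open("myface.jpg").convert("RGB")).unsqueeze(0)
with torch.no_grad():
    aged = G(img)

# Undo the [-1, 1] normalization and save next to the input.
out = ((aged.squeeze(0) + 1) / 2).clamp(0, 1)
transforms.ToPILImage()(out).save("myface_old.jpg")
```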
````diff
@@ -65,7 +73,7 @@ The `aging_gan.inference` module is currently a stub. Once implemented, you will
 ## Repository Structure
 
 - `src/aging_gan/` – core modules (`train.py`, `model.py`, etc.)
-- `scripts/` – helper shell scripts for training and (placeholder) inference
+- `scripts/` – helper scripts for training and inference
 - `notebooks/` – exploratory notebooks
 
 ## Requirements
````

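The Evaluation bullet keeps FID as the headline metric, now computed on both the validation and test sets. With `torchmetrics` the computation looks roughly like the sketch below; whether the project uses `torchmetrics` or another FID implementation is not visible from this diff.

```python
import torch
from torchmetrics.image.fid import FrechetInceptionDistance

# feature=64 uses a small Inception layer so the toy example runs quickly;
# normalize=True accepts float images in [0, 1] instead of uint8.
fid = FrechetInceptionDistance(feature=64, normalize=True)

real_batch = torch.rand(8, 3, 256, 256)  # stand-in for real "Old" images
fake_batch = torch.rand(8, 3, 256, 256)  # stand-in for generated images

fid.update(real_batch, real=True)
fid.update(fake_batch, real=False)
print(float(fid.compute()))
```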
tests/test_data.py

Lines changed: 3 additions & 0 deletions
````diff
@@ -5,6 +5,7 @@
 
 
 def create_utk_dataset(tmp_path, num_per_split=6):
+    """Create a tiny UTKFace-style directory for testing."""
     root = Path(tmp_path)
     ds_root = root / "utkface_aligned_cropped" / "UTKFace"
     ds_root.mkdir(parents=True)
@@ -21,6 +22,7 @@ def create_utk_dataset(tmp_path, num_per_split=6):
 
 
 def test_utkface_len_and_getitem(tmp_path):
+    """UTKFace yields image/age pairs and correct length."""
     root = create_utk_dataset(tmp_path)
     ds = data.UTKFace(str(root))
     assert len(ds) == 12
@@ -30,6 +32,7 @@ def test_utkface_len_and_getitem(tmp_path):
 
 
 def test_make_unpaired_loader(tmp_path):
+    """Loader returns equal-sized batches of young and old images."""
     root = create_utk_dataset(tmp_path)
     loader = data.make_unpaired_loader(
         str(root),
````

tests/test_model.py

Lines changed: 3 additions & 0 deletions
````diff
@@ -3,20 +3,23 @@
 
 
 def test_generator_output_shape():
+    """Generator preserves input image dimensions."""
     G = model.Generator(ngf=8, n_residual_blocks=1)
     x = torch.randn(2, 3, 64, 64)
     y = G(x)
     assert y.shape == x.shape
 
 
 def test_discriminator_output_shape():
+    """Discriminator outputs a single logit per image."""
     D = model.Discriminator(ndf=8)
     x = torch.randn(2, 3, 64, 64)
     out = D(x)
     assert out.shape == (2, 1)
 
 
 def test_initialize_models_types():
+    """Model initializer returns correct component classes."""
     G, F, DX, DY = model.initialize_models(ngf=8, ndf=8, n_blocks=1)
     assert isinstance(G, model.Generator)
     assert isinstance(F, model.Generator)
````

tests/test_train.py

Lines changed: 3 additions & 0 deletions
````diff
@@ -7,6 +7,7 @@
 
 
 def test_parse_args_defaults(monkeypatch):
+    """CLI parser returns expected default arguments."""
     monkeypatch.setattr(sys, "argv", ["prog"])
     args = train.parse_args()
     assert args.gen_lr == 2e-4
@@ -15,6 +16,7 @@ def test_parse_args_defaults(monkeypatch):
 
 
 def test_initialize_loss_functions_defaults():
+    """Loss initializer provides default weights and criteria."""
     mse, l1, adv, cyc, ident = train.initialize_loss_functions()
     assert isinstance(mse, torch.nn.MSELoss)
     assert adv == 2.0
@@ -23,6 +25,7 @@ def test_initialize_loss_functions_defaults():
 
 
 def test_make_schedulers_decay():
+    """Learning rate scheduler should decrease learning rate."""
     cfg = SimpleNamespace(num_train_epochs=4)
     models = model.initialize_models(ngf=8, ndf=8, n_blocks=1)
     opts = [torch.optim.SGD(m.parameters(), lr=1.0) for m in models]
````

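`test_make_schedulers_decay` exercises the linear learning-rate decay, and the surrounding context ("different lr for gen and disc", the README's Training Utilities bullet) points to separate generator and discriminator rates. A sketch of that pattern with `LambdaLR`; the real `make_schedulers` signature is not shown here, and the discriminator rate below is an assumption (only `gen_lr == 2e-4` is confirmed by the test).

```python
import torch
from torch.optim.lr_scheduler import LambdaLR

num_train_epochs = 4
gen_lr, disc_lr = 2e-4, 1e-4  # disc_lr is illustrative; gen_lr matches the tested default

gen = torch.nn.Linear(8, 8)   # stand-ins for the CycleGAN generators and discriminators
disc = torch.nn.Linear(8, 8)

opt_g = torch.optim.Adam(gen.parameters(), lr=gen_lr, betas=(0.5, 0.999))
opt_d = torch.optim.Adam(disc.parameters(), lr=disc_lr, betas=(0.5, 0.999))


def linear_decay(epoch: int) -> float:
    # Multiplier that falls linearly from 1.0 at epoch 0 toward 0.0 at the final epoch.
    return 1.0 - epoch / num_train_epochs


sched_g = LambdaLR(opt_g, lr_lambda=linear_decay)
sched_d = LambdaLR(opt_d, lr_lambda=linear_decay)

for epoch in range(num_train_epochs):
    # ... optimizer steps for generators and discriminators would go here ...
    sched_g.step()
    sched_d.step()
    print(epoch, sched_g.get_last_lr(), sched_d.get_last_lr())
```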
tests/test_utils.py

Lines changed: 3 additions & 0 deletions
````diff
@@ -6,6 +6,7 @@
 
 
 def test_set_seed_reproducibility():
+    """Seeding should make RNGs deterministic."""
     utils.set_seed(123)
     a = random.random()
     b = np.random.rand()
@@ -18,10 +19,12 @@ def test_set_seed_reproducibility():
 
 
 def test_get_device_cpu():
+    """Utility returns CPU when CUDA is unavailable."""
     assert utils.get_device().type == "cpu"
 
 
 def test_save_checkpoint(tmp_path):
+    """Checkpoint file is created on disk."""
     model = torch.nn.Linear(1, 1)
     opt = torch.optim.SGD(model.parameters(), lr=0.1)
     sched = torch.optim.lr_scheduler.LambdaLR(opt, lambda _: 1)
````
