Skip to content

Commit 7e7c8e2

Browse files
committed
added tests
1 parent abaf6b9 commit 7e7c8e2

File tree

6 files changed

+158
-0
lines changed

6 files changed

+158
-0
lines changed

src/aging_gan/__init__.py

Whitespace-only changes.

tests/conftest.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
import sys
2+
from pathlib import Path
3+
4+
# Add src directory to path for tests
5+
SRC_ROOT = Path(__file__).resolve().parents[1] / "src"
6+
sys.path.insert(0, str(SRC_ROOT))

tests/test_data.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from pathlib import Path
2+
from PIL import Image
3+
import torchvision.transforms as T
4+
from aging_gan import data
5+
6+
7+
def create_utk_dataset(tmp_path, num_per_split=6):
8+
root = Path(tmp_path)
9+
ds_root = root / "utkface_aligned_cropped" / "UTKFace"
10+
ds_root.mkdir(parents=True)
11+
# create young images ages 18..(18+num_per_split-1)
12+
for i in range(num_per_split):
13+
age = 18 + i
14+
img = Image.new("RGB", (32, 32), color=(i, i, i))
15+
img.save(ds_root / f"{age}_0_0_202001010000{i}.jpg")
16+
for i in range(num_per_split):
17+
age = 40 + i
18+
img = Image.new("RGB", (32, 32), color=(i, i, i))
19+
img.save(ds_root / f"{age}_0_0_202001010100{i}.jpg")
20+
return root
21+
22+
23+
def test_utkface_len_and_getitem(tmp_path):
24+
root = create_utk_dataset(tmp_path)
25+
ds = data.UTKFace(str(root))
26+
assert len(ds) == 12
27+
img, age = ds[0]
28+
assert isinstance(age, int)
29+
assert isinstance(img, Image.Image)
30+
31+
32+
def test_make_unpaired_loader(tmp_path):
33+
root = create_utk_dataset(tmp_path)
34+
loader = data.make_unpaired_loader(
35+
str(root),
36+
"train",
37+
T.Compose([T.ToTensor()]),
38+
batch_size=2,
39+
num_workers=1,
40+
seed=0,
41+
young_max=23,
42+
old_min=40,
43+
)
44+
x, y = next(iter(loader))
45+
assert x.shape == y.shape
46+
assert x.shape[0] == 2

tests/test_model.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import torch
2+
from aging_gan import model
3+
4+
5+
def test_generator_output_shape():
6+
G = model.Generator(ngf=8, n_residual_blocks=1)
7+
x = torch.randn(2, 3, 64, 64)
8+
y = G(x)
9+
assert y.shape == x.shape
10+
11+
12+
def test_discriminator_output_shape():
13+
D = model.Discriminator(ndf=8)
14+
x = torch.randn(2, 3, 64, 64)
15+
out = D(x)
16+
assert out.shape == (2, 1)
17+
18+
19+
def test_initialize_models_types():
20+
G, F, DX, DY = model.initialize_models(ngf=8, ndf=8, n_blocks=1)
21+
assert isinstance(G, model.Generator)
22+
assert isinstance(F, model.Generator)
23+
assert isinstance(DX, model.Discriminator)
24+
assert isinstance(DY, model.Discriminator)

tests/test_train.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import torch
2+
from types import SimpleNamespace
3+
from aging_gan import train, model
4+
5+
6+
import sys
7+
8+
9+
def test_parse_args_defaults(monkeypatch):
10+
monkeypatch.setattr(sys, "argv", ["prog"])
11+
args = train.parse_args()
12+
assert args.gen_lr == 2e-4
13+
assert args.disc_lr == 1e-4
14+
assert args.num_train_epochs == 100
15+
16+
17+
def test_initialize_loss_functions_defaults():
18+
mse, l1, adv, cyc, ident = train.initialize_loss_functions()
19+
assert isinstance(mse, torch.nn.MSELoss)
20+
assert adv == 2.0
21+
assert cyc == 10.0
22+
assert ident == 7.0
23+
24+
25+
def test_make_schedulers_decay():
26+
cfg = SimpleNamespace(num_train_epochs=4)
27+
models = model.initialize_models(ngf=8, ndf=8, n_blocks=1)
28+
opts = [torch.optim.SGD(m.parameters(), lr=1.0) for m in models]
29+
sched_G, _, _, _ = train.make_schedulers(cfg, *opts)
30+
sched_G.step() # epoch 0
31+
assert opts[0].param_groups[0]["lr"] == 1.0
32+
sched_G.step(3)
33+
assert opts[0].param_groups[0]["lr"] < 1.0

tests/test_utils.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import random
2+
import numpy as np
3+
import torch
4+
from pathlib import Path
5+
from aging_gan import utils
6+
7+
8+
def test_set_seed_reproducibility():
9+
utils.set_seed(123)
10+
a = random.random()
11+
b = np.random.rand()
12+
c = torch.rand(1)
13+
14+
utils.set_seed(123)
15+
assert random.random() == a
16+
assert np.random.rand() == b
17+
assert torch.allclose(torch.rand(1), c)
18+
19+
20+
def test_get_device_cpu():
21+
assert utils.get_device().type == "cpu"
22+
23+
24+
def test_save_checkpoint(tmp_path):
25+
model = torch.nn.Linear(1, 1)
26+
opt = torch.optim.SGD(model.parameters(), lr=0.1)
27+
sched = torch.optim.lr_scheduler.LambdaLR(opt, lambda _: 1)
28+
29+
utils.save_checkpoint(
30+
1,
31+
model,
32+
model,
33+
model,
34+
model,
35+
opt,
36+
opt,
37+
opt,
38+
opt,
39+
sched,
40+
sched,
41+
sched,
42+
sched,
43+
kind="best",
44+
)
45+
ckpt_file = (
46+
Path(__file__).resolve().parents[1] / "outputs" / "checkpoints" / "best.pth"
47+
)
48+
assert ckpt_file.exists()
49+
ckpt_file.unlink()

0 commit comments

Comments
 (0)