Skip to content

Commit bde72ca

Browse files
committed
added smoke tests
1 parent 47b8c10 commit bde72ca

File tree

2 files changed

+133
-0
lines changed

2 files changed

+133
-0
lines changed

.github/workflows/smoke-tests.yaml

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
name: Smoke Tests
2+
3+
on:
4+
push:
5+
branches: [ "main" ]
6+
7+
jobs:
8+
smoke-tests:
9+
runs-on: ubuntu-latest
10+
env:
11+
WANDB_MODE: offline
12+
13+
steps:
14+
- name: Checkout
15+
uses: actions/checkout@v4
16+
17+
- name: Set up Python
18+
uses: actions/setup-python@v5
19+
with:
20+
python-version: '3.10'
21+
cache: 'pip'
22+
23+
- name: Install dependencies
24+
run: |
25+
python -m pip install --upgrade pip
26+
pip install -e .
27+
pip install -r requirements-dev.txt
28+
29+
- name: Run formatting, linting
30+
run: |
31+
pre-commit install
32+
pre-commit run --all-files
33+
34+
- name: Run tests
35+
run: pytest

tests/test_smoke_training.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import torch
2+
from types import SimpleNamespace
3+
import torchvision.transforms as T
4+
from aging_gan import data, model, train
5+
from test_data import create_utk_dataset
6+
7+
8+
class DummyAccelerator:
9+
def __init__(self):
10+
self.device = torch.device("cpu")
11+
12+
def autocast(self):
13+
from contextlib import nullcontext
14+
15+
return nullcontext()
16+
17+
def backward(self, loss):
18+
loss.backward()
19+
20+
def clip_grad_norm_(self, params, max_norm):
21+
torch.nn.utils.clip_grad_norm_(params, max_norm)
22+
23+
24+
class DummyFID:
25+
def reset(self):
26+
pass
27+
28+
def update(self, *args, **kwargs):
29+
pass
30+
31+
def compute(self):
32+
return torch.tensor(0.0)
33+
34+
35+
def test_smoke_training(tmp_path, monkeypatch):
36+
root = create_utk_dataset(tmp_path)
37+
transform = T.Compose([T.ToTensor()])
38+
train_loader = data.make_unpaired_loader(
39+
str(root),
40+
"train",
41+
transform,
42+
batch_size=2,
43+
num_workers=1,
44+
seed=0,
45+
young_max=23,
46+
old_min=40,
47+
)
48+
val_loader = data.make_unpaired_loader(
49+
str(root),
50+
"valid",
51+
transform,
52+
batch_size=2,
53+
num_workers=1,
54+
seed=0,
55+
young_max=23,
56+
old_min=40,
57+
)
58+
59+
G, F, DX, DY = model.initialize_models(ngf=4, ndf=4, n_blocks=1)
60+
opt_cfg = SimpleNamespace(
61+
gen_lr=1e-3, disc_lr=1e-3, weight_decay=0.0, num_train_epochs=1
62+
)
63+
opt_G, opt_F, opt_DX, opt_DY = train.initialize_optimizers(opt_cfg, G, F, DX, DY)
64+
sched_G, sched_F, sched_DX, sched_DY = train.make_schedulers(
65+
opt_cfg, opt_G, opt_F, opt_DX, opt_DY
66+
)
67+
mse, l1, adv, cyc, ident = train.initialize_loss_functions()
68+
accelerator = DummyAccelerator()
69+
fid = DummyFID()
70+
monkeypatch.setattr(train, "wandb", SimpleNamespace(log=lambda *a, **k: None))
71+
monkeypatch.setattr(train, "generate_and_save_samples", lambda *a, **k: None)
72+
cfg = SimpleNamespace(steps_for_logging_metrics=1, num_sample_generations_to_save=1)
73+
metrics = train.perform_epoch(
74+
cfg,
75+
train_loader,
76+
val_loader,
77+
G,
78+
F,
79+
DX,
80+
DY,
81+
mse,
82+
l1,
83+
adv,
84+
cyc,
85+
ident,
86+
opt_G,
87+
opt_F,
88+
opt_DX,
89+
opt_DY,
90+
sched_G,
91+
sched_F,
92+
sched_DX,
93+
sched_DY,
94+
0,
95+
accelerator,
96+
fid,
97+
)
98+
assert "val/loss_gen_total" in metrics

0 commit comments

Comments
 (0)