
Commit 2bad8ab

added testing on test set for inference

1 parent: c527b0a

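This change turns `src/aging_gan/inference.py` into a dual-purpose entry point: `--mode infer` (the default) keeps the single-image behavior, while `--mode test` scores a checkpoint on the held-out test split. A sketch of both invocations, assuming the package is importable and the script is run as a module; note that `--input` is declared `required=True`, so as written even the test mode needs a placeholder path:

```bash
# Single-image inference (default mode)
python -m aging_gan.inference --mode infer --input face.jpg --direction young2old

# Test-set evaluation with the same checkpoint; --input is still required
# by the parser even though the test branch never reads it
python -m aging_gan.inference --mode test --input placeholder.jpg --eval_batch_size 32
```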
File tree

1 file changed: +155 -51 lines

src/aging_gan/inference.py

Lines changed: 155 additions & 51 deletions

```diff
@@ -1,45 +1,91 @@
-"""Command-line interface for running a trained generator on a single image."""
+"""Command-line interface for running a trained generator on a single image or evaluating the test set metrics."""
 
 import argparse
-from pathlib import Path
-
+import torchvision.transforms as T
 import torch
+from pathlib import Path
 from PIL import Image
-import torchvision.transforms as T
+from accelerate import Accelerator
+from torchmetrics.image.fid import FrechetInceptionDistance
+
 from aging_gan.model import initialize_models
 from aging_gan.utils import get_device
+from aging_gan.data import prepare_dataset
+from aging_gan.train import evaluate_epoch, initialize_loss_functions
 
 
 def parse_args() -> argparse.Namespace:
     """Parse CLI arguments for running inference."""
     p = argparse.ArgumentParser(
-        description="Run one-off inference with a trained Aging-GAN generator"
+        description="Run inference on one image or evaluate metrics on the test set."
+    )
+    p.add_argument(
+        "--mode",
+        choices=["infer", "test"],
+        default="infer",
+        help="Mode to run: 'infer' for single-image inference or 'test' for test-set evaluation",
+    )
+    p.add_argument(
+        "--input",
+        type=str,
+        required=True,
+        help="Path to source image (required for 'infer')",
     )
-    p.add_argument("--input", type=str, required=True, help="Path to source image")
     p.add_argument(
         "--output",
         type=str,
         default=None,
-        help="Where to save result (defaults beside input)",
+        help="Where to save inference result (defaults beside input)",
     )
     p.add_argument(
         "--ckpt",
         type=str,
         default=str(
             Path(__file__).resolve().parents[2] / "outputs/checkpoints/best.pth"
         ),
-        help="Checkpoint to load (default: outputs/checkpoints/best.pth)",
+        help="Checkpoint to load",
     )
     p.add_argument(
         "--direction",
         choices=["young2old", "old2young"],
         default="young2old",
-        help="'young2old' uses generator G, 'old2young' uses generator F",
+        help="'young2old' uses generator G, 'old2young' uses generator F (only for 'infer')",
+    )
+    p.add_argument(
+        "--eval_batch_size",
+        type=int,
+        default=32,
+        help="Batch size for test-set evaluation (only for 'test')",
     )
     p.add_argument(
-        "--train_img_size",
+        "--num_workers",
         type=int,
-        default=256,
-        help="The same img_size you used for training and evaluating.",
+        default=3,
+        help="Number of DataLoader workers for test-set evaluation (only for 'test')",
+    )
+    p.add_argument(
+        "--lambda_adv_value",
+        type=float,
+        default=2.0,
+        help="Weight for adversarial loss (only for 'test')",
+    )
+    p.add_argument(
+        "--lambda_cyc_value",
+        type=float,
+        default=4.0,
+        help="Weight for cycle-consistency loss (only for 'test')",
     )
+    p.add_argument(
+        "--lambda_id_value",
+        type=float,
+        default=0.5,
+        help="Weight for identity loss (only for 'test')",
+    )
+    p.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="Random seed for data loading (only for 'test')",
+    )
+
     return p.parse_args()
```

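The three `lambda_*` flags expose the loss weights used during training. Assuming the repo follows the usual CycleGAN objective (the actual formulation lives in `aging_gan.train`, not in this diff), the weights presumably combine the per-term losses as

$$
\mathcal{L}_{\text{gen}} = \lambda_{\text{adv}}\,\mathcal{L}_{\text{adv}} + \lambda_{\text{cyc}}\,\mathcal{L}_{\text{cyc}} + \lambda_{\text{id}}\,\mathcal{L}_{\text{id}},
$$

with defaults $\lambda_{\text{adv}} = 2.0$, $\lambda_{\text{cyc}} = 4.0$, $\lambda_{\text{id}} = 0.5$. The `mse, l1, ...` unpacking later in this file suggests MSE as the adversarial criterion and L1 for the cycle and identity terms, though only `aging_gan.train` can confirm that.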

```diff
@@ -48,20 +94,52 @@
 @torch.inference_mode()
 def main() -> None:
-    """Load a checkpoint and generate an aged face from ``--input``."""
+    """Load a checkpoint and generate an aged face from ``--input`` or evaluate on the test set."""
     cfg = parse_args()
     device = get_device()
 
-    # image helpers
-    preprocess = T.Compose(
-        [
-            T.Resize(
-                (cfg.train_img_size + 50, cfg.train_img_size + 50), antialias=True
-            ),
-            T.CenterCrop(cfg.train_img_size),
-            T.ToTensor(),
-            T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
-        ]
-    )
+    # Single-image inference
+    if cfg.mode == "infer":
+        # image helpers
+        preprocess = T.Compose(
+            [
+                T.Resize((256 + 50, 256 + 50), antialias=True),
+                T.CenterCrop(256),
+                T.ToTensor(),
+                T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+            ]
+        )
+
+        postprocess = T.Compose(
+            [T.Normalize(mean=[-1, -1, -1], std=[2, 2, 2]), T.ToPILImage()]
+        )
+
+        # Load generators and checkpoint
+        G, F, *_ = initialize_models()  # returns G, F, DX, DY
+        ckpt = torch.load(
+            cfg.ckpt, map_location=device
+        )  # same keys as used in train.py
+        if cfg.direction == "young2old":
+            G.load_state_dict(ckpt["G"])
+            generator = G.eval().to(device)
+        else:
+            F.load_state_dict(ckpt["F"])
+            generator = F.eval().to(device)
 
-    postprocess = T.Compose(
-        [T.Normalize(mean=[-1, -1, -1], std=[2, 2, 2]), T.ToPILImage()]
+        # Read & preprocess
+        img_in = Image.open(cfg.input).convert("RGB")
+        x = preprocess(img_in).unsqueeze(0).to(device)  # (1,3,H,W)
+
+        # Forward pass
+        y_hat = generator(x).clamp(-1, 1)
+
+        # Save
+        img_out = postprocess(y_hat.squeeze().cpu())
+        out_path = (
+            Path(cfg.output)
+            if cfg.output
+            else Path(cfg.input).with_stem(Path(cfg.input).stem + f"_{cfg.direction}")
+        )
+        img_out.save(out_path)
+        print(f"Saved result -> {out_path}")
+        return
+
```
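
On the infer path, `postprocess` exactly inverts the input normalization: `T.Normalize((0.5,)*3, (0.5,)*3)` maps `x` in `[0, 1]` to `2x - 1` in `[-1, 1]`, and `T.Normalize(mean=[-1]*3, std=[2]*3)` computes `(y - (-1)) / 2 = (y + 1) / 2`, mapping back to `[0, 1]` before `ToPILImage`. A standalone round-trip check, independent of the repo:

```python
import torch
import torchvision.transforms as T

to_model = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # [0, 1] -> [-1, 1]
to_image = T.Normalize(mean=[-1, -1, -1], std=[2, 2, 2])   # [-1, 1] -> [0, 1]

x = torch.rand(3, 8, 8)  # stand-in for a ToTensor() output in [0, 1]
assert torch.allclose(to_image(to_model(x)), x, atol=1e-6)
```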

```diff
@@ -68,31 +146,57 @@
+    # Test-set evaluation
+    # speedups (enable the cuDNN auto-tuner, which is good for fixed input shapes)
+    torch.backends.cudnn.benchmark = True
+
+    # Prepare data loaders
+    _, _, test_loader = prepare_dataset(
+        cfg.eval_batch_size, cfg.eval_batch_size, cfg.num_workers, seed=cfg.seed
     )
 
-    # Load generators and checkpoint
-    G, F, *_ = initialize_models()  # returns G, F, DX, DY
-    ckpt = torch.load(cfg.ckpt, map_location=device)  # same keys as used in train.py
-    if cfg.direction == "young2old":
-        G.load_state_dict(ckpt["G"])
-        generator = G.eval().to(device)
-    else:
-        F.load_state_dict(ckpt["F"])
-        generator = F.eval().to(device)
-
-    # Read & preprocess
-    img_in = Image.open(cfg.input).convert("RGB")
-    x = preprocess(img_in).unsqueeze(0).to(device)  # (1,3,H,W)
-
-    # Forward pages
-    y_hat = generator(x).clamp(-1, 1)
-
-    # Save
-    img_out = postprocess(y_hat.squeeze().cpu())
-    out_path = (
-        Path(cfg.output)
-        if cfg.output
-        else Path(cfg.input).with_stem(Path(cfg.input).stem + f"_{cfg.direction}")
+    # Load models and checkpoint
+    G, F, DX, DY = initialize_models()
+    ckpt = torch.load(cfg.ckpt, map_location="cpu")
+    G.load_state_dict(ckpt["G"])
+    F.load_state_dict(ckpt["F"])
+    DX.load_state_dict(ckpt["DX"])
+    DY.load_state_dict(ckpt["DY"])
+
+    # Set up accelerator for mixed precision, parallelism, and moving to device
+    accelerator = Accelerator(mixed_precision="fp16")
+    G, F, DX, DY, test_loader = accelerator.prepare(G, F, DX, DY, test_loader)
+
+    # Initialize loss functions and FID metric
+    mse, l1, lambda_adv_value, lambda_cyc_value, lambda_id_value = (
+        initialize_loss_functions(
+            cfg.lambda_adv_value, cfg.lambda_cyc_value, cfg.lambda_id_value
+        )
     )
-    img_out.save(out_path)
-    print(f"Saved result -> {out_path}")
+    fid_metric = FrechetInceptionDistance(feature=2048, normalize=True).to(
+        accelerator.device
+    )
+
+    # Evaluate and print metrics
+    # switch to eval mode
+    for m in (G, F, DX, DY):
+        m.eval()
+    with torch.no_grad():
+        metrics = evaluate_epoch(
+            G,
+            F,
+            DX,
+            DY,
+            test_loader,
+            "test",
+            mse,
+            l1,
+            lambda_adv_value,
+            lambda_cyc_value,
+            lambda_id_value,
+            fid_metric,
+            accelerator,
+        )
+    print("Test-set metrics:")
+    for name, value in metrics.items():
+        print(f"{name}: {value:.6f}")
 
 
 if __name__ == "__main__":
```

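The test branch leans on `accelerate` for device placement and fp16 autocast, and on `torchmetrics` for FID. A minimal sketch of the `FrechetInceptionDistance` API as configured above, with random tensors standing in for real and generated batches; `normalize=True` means the metric expects float images in `[0, 1]`, so `[-1, 1]` generator outputs need rescaling first, presumably handled inside `evaluate_epoch`:

```python
import torch
from torchmetrics.image.fid import FrechetInceptionDistance

# feature=2048 selects the final InceptionV3 pooling layer, the standard FID feature
fid = FrechetInceptionDistance(feature=2048, normalize=True)

real = torch.rand(16, 3, 256, 256)             # placeholder "real" batch in [0, 1]
fake_pm1 = torch.rand(16, 3, 256, 256) * 2 - 1  # stand-in generator output in [-1, 1]
fake = (fake_pm1 + 1) / 2                       # rescale to [0, 1] before updating

fid.update(real, real=True)
fid.update(fake, real=False)
print(fid.compute())  # scalar tensor; lower is better
```

FID estimated from a handful of images is noisy; the evaluation here presumably accumulates the whole test split across batches before calling `compute()`.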