|
1 | | -"""Command-line interface for running a trained generator on a single image.""" |
| 1 | +"""Command-line interface for running a trained generator on a single image or evaluating the test set metrics.""" |
2 | 2 |
|
3 | 3 | import argparse |
4 | | -from pathlib import Path |
5 | | - |
| 4 | +import torchvision.transforms as T |
6 | 5 | import torch |
| 6 | +from pathlib import Path |
7 | 7 | from PIL import Image |
8 | | -import torchvision.transforms as T |
| 8 | +from accelerate import Accelerator |
| 9 | +from torchmetrics.image.fid import FrechetInceptionDistance |
| 10 | + |
9 | 11 | from aging_gan.model import initialize_models |
10 | 12 | from aging_gan.utils import get_device |
| 13 | +from aging_gan.data import prepare_dataset |
| 14 | +from aging_gan.train import evaluate_epoch, initialize_loss_functions |
11 | 15 |
|
12 | 16 |
|
def parse_args() -> argparse.Namespace:
    """Parse CLI arguments for single-image inference or test-set evaluation.

    ``--input`` is mandatory only when ``--mode`` is ``infer`` (the default);
    in ``test`` mode it may be omitted and is then ``None``.

    Returns:
        The parsed arguments as an ``argparse.Namespace``.

    Raises:
        SystemExit: Raised by argparse (exit status 2) on invalid arguments,
            including ``infer`` mode without ``--input``.
    """
    p = argparse.ArgumentParser(
        description="Run inference on one image or evaluate metrics on the test set."
    )
    p.add_argument(
        "--mode",
        choices=["infer", "test"],
        default="infer",
        help="Mode to run: 'infer' for single-image inference or 'test' for test-set evaluation",
    )
    # Not marked required at the argparse level: the requirement depends on
    # --mode and is enforced after parsing, below.
    p.add_argument(
        "--input",
        type=str,
        default=None,
        help="Path to source image (required for 'infer')",
    )
    p.add_argument(
        "--output",
        type=str,
        default=None,
        help="Where to save inference result (defaults beside input)",
    )
    p.add_argument(
        "--ckpt",
        type=str,
        # Default is resolved relative to parents[2] of this file's real path.
        default=str(
            Path(__file__).resolve().parents[2] / "outputs/checkpoints/best.pth"
        ),
        help="Checkpoint to load",
    )
    p.add_argument(
        "--direction",
        choices=["young2old", "old2young"],
        default="young2old",
        help="'young2old' uses generator G, 'old2young' uses generator F (only for 'infer')",
    )
    p.add_argument(
        "--eval_batch_size",
        type=int,
        default=32,
        help="Batch size for test-set evaluation (only for 'test')",
    )
    p.add_argument(
        "--num_workers",
        type=int,
        default=3,
        help="Number of DataLoader workers for test-set evaluation (only for 'test')",
    )
    p.add_argument(
        "--lambda_adv_value",
        type=float,
        default=2.0,
        help="Weight for adversarial loss (only for 'test')",
    )
    p.add_argument(
        "--lambda_cyc_value",
        type=float,
        default=4.0,
        help="Weight for cycle-consistency loss (only for 'test')",
    )
    p.add_argument(
        "--lambda_id_value",
        type=float,
        default=0.5,
        help="Weight for identity loss (only for 'test')",
    )
    p.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for data loading (only for 'test')",
    )

    cfg = p.parse_args()
    # Enforce the mode-dependent requirement with argparse's own error path so
    # the user sees the usual usage message and exit status 2.
    if cfg.mode == "infer" and cfg.input is None:
        p.error("--input is required when --mode is 'infer'")
    return cfg
46 | 92 |
|
47 | 93 |
|
def _run_inference(cfg: argparse.Namespace, device) -> None:
    """Translate the image at ``cfg.input`` with one generator and save the result.

    Args:
        cfg: Parsed CLI arguments; reads ``input``, ``output``, ``ckpt`` and
            ``direction``.
        device: Device returned by ``get_device()``; the checkpoint, generator
            and input tensor are placed on it.
    """
    img_size = 256  # crop size fed to the generator
    preprocess = T.Compose(
        [
            # Resize slightly larger than the crop, then take the center.
            T.Resize((img_size + 50, img_size + 50), antialias=True),
            T.CenterCrop(img_size),
            T.ToTensor(),
            T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # [0, 1] -> [-1, 1]
        ]
    )
    # Inverse of the normalization above: (x - (-1)) / 2 maps [-1, 1] -> [0, 1].
    postprocess = T.Compose(
        [T.Normalize(mean=[-1, -1, -1], std=[2, 2, 2]), T.ToPILImage()]
    )

    # Load generators and checkpoint
    G, F, *_ = initialize_models()  # returns G, F, DX, DY
    # NOTE(review): depending on the installed torch version's `weights_only`
    # default, torch.load may unpickle arbitrary objects -- only load trusted
    # checkpoints.
    ckpt = torch.load(cfg.ckpt, map_location=device)  # same keys as used in train.py
    if cfg.direction == "young2old":
        G.load_state_dict(ckpt["G"])
        generator = G.eval().to(device)
    else:
        F.load_state_dict(ckpt["F"])
        generator = F.eval().to(device)

    # Read & preprocess; the context manager closes the underlying file handle.
    with Image.open(cfg.input) as src:
        img_in = src.convert("RGB")
    x = preprocess(img_in).unsqueeze(0).to(device)  # (1,3,H,W)

    # Forward pass
    y_hat = generator(x).clamp(-1, 1)

    # Save; squeeze(0) drops only the batch axis added by unsqueeze(0) above.
    img_out = postprocess(y_hat.squeeze(0).cpu())
    if cfg.output:
        out_path = Path(cfg.output)
    else:
        # Default: next to the input, with the direction appended to the stem.
        src_path = Path(cfg.input)
        out_path = src_path.with_stem(src_path.stem + f"_{cfg.direction}")
    img_out.save(out_path)
    print(f"Saved result -> {out_path}")


def _run_test(cfg: argparse.Namespace) -> None:
    """Evaluate the checkpoint on the test split and print each metric.

    Args:
        cfg: Parsed CLI arguments; reads ``eval_batch_size``, ``num_workers``,
            ``seed``, ``ckpt`` and the three ``lambda_*_value`` weights.
    """
    # speedups (Enable cuDNN auto-tuner which is good for fixed input shapes)
    torch.backends.cudnn.benchmark = True

    # Prepare data loaders; only the third returned loader (test) is used.
    _, _, test_loader = prepare_dataset(
        cfg.eval_batch_size, cfg.eval_batch_size, cfg.num_workers, seed=cfg.seed
    )

    # Load models and checkpoint on CPU first; accelerate moves them later.
    G, F, DX, DY = initialize_models()
    # NOTE(review): see the torch.load trust caveat -- only load checkpoints
    # from a trusted source.
    ckpt = torch.load(cfg.ckpt, map_location="cpu")
    for key, model in (("G", G), ("F", F), ("DX", DX), ("DY", DY)):
        model.load_state_dict(ckpt[key])

    # Set up accelerator for mixed precision, parallelism, and moving to device
    accelerator = Accelerator(mixed_precision="fp16")
    G, F, DX, DY, test_loader = accelerator.prepare(G, F, DX, DY, test_loader)

    # Initialize loss functions and FID metric
    mse, l1, lambda_adv_value, lambda_cyc_value, lambda_id_value = (
        initialize_loss_functions(
            cfg.lambda_adv_value, cfg.lambda_cyc_value, cfg.lambda_id_value
        )
    )
    fid_metric = FrechetInceptionDistance(feature=2048, normalize=True).to(
        accelerator.device
    )

    # Evaluate and print metrics
    # change to eval mode
    for m in (G, F, DX, DY):
        m.eval()
    with torch.no_grad():
        metrics = evaluate_epoch(
            G,
            F,
            DX,
            DY,
            test_loader,
            "test",
            mse,
            l1,
            lambda_adv_value,
            lambda_cyc_value,
            lambda_id_value,
            fid_metric,
            accelerator,
        )
    print("Test-set metrics:")
    for name, value in metrics.items():
        print(f"{name}: {value:.6f}")


@torch.inference_mode()
def main() -> None:
    """Load a checkpoint and either age one image or evaluate on the test set.

    Dispatches on ``--mode``: ``infer`` runs a single image through one
    generator and saves it; ``test`` computes and prints test-set metrics.
    Both paths execute inside ``torch.inference_mode()``.
    """
    cfg = parse_args()
    device = get_device()

    # Single-image inference
    if cfg.mode == "infer":
        _run_inference(cfg, device)
        return

    # Test-set evaluation
    _run_test(cfg)
|
97 | 201 |
|
98 | 202 | if __name__ == "__main__": |
|
0 commit comments