
Commit c527b0a

added flag to change image size in inference; added gradio
1 parent 8ffceee commit c527b0a

4 files changed: +114 -15 lines changed

app.py

Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
"""
Gradio demo for Aging-GAN: upload a face, choose direction, and get an aged or rejuvenated output.
"""

import gradio as gr
import torch
from pathlib import Path
from PIL import Image
import torchvision.transforms as T

from aging_gan.model import initialize_models


# Utils
def get_device() -> torch.device:
    """Return CUDA device if available else CPU."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Transforms
preprocess = T.Compose(
    [
        T.Resize((256 + 50, 256 + 50), antialias=True),
        T.CenterCrop(256),
        T.ToTensor(),
        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

postprocess = T.Compose([T.Normalize(mean=[-1, -1, -1], std=[2, 2, 2]), T.ToPILImage()])

# Load models & checkpoint once
device = get_device()

# initialize G (young→old) and F (old→young)
G, F, _, _ = initialize_models()
ckpt_path = Path("outputs/checkpoints/epoch_0030.pth")
ckpt = torch.load(ckpt_path, map_location=device)

G.load_state_dict(ckpt["G"])
F.load_state_dict(ckpt["F"])
G.eval().to(device)
F.eval().to(device)


# Inference function
def infer(image: Image.Image, direction: str) -> Image.Image:
    """
    Run a single forward pass through the chosen generator.
    """
    # preprocess
    x = preprocess(image).unsqueeze(0).to(device)  # (1,3,256,256)

    # generate
    with torch.inference_mode():
        if direction == "young2old":
            y_hat = G(x)
        else:
            y_hat = F(x)
        y_hat = torch.clamp(y_hat, -1, 1)

    # postprocess & return PIL image
    out = postprocess(y_hat.squeeze(0).cpu())
    return out


# Launch Gradio
demo = gr.Interface(
    fn=infer,
    inputs=[
        gr.Image(type="pil", label="Input Face"),
        gr.Radio(
            choices=["young2old", "old2young"],
            value="young2old",
            label="Transformation Direction",
        ),
    ],
    outputs=gr.Image(type="pil", label="Output Face"),
    title="Aging-GAN Demo",
    description=(
        "Upload a portrait, select “young2old” to age it or “old2young” to rejuvenate. "
        "Powered by a ResNet-style CycleGAN generator. "
        "TIP: Upload close-up photos of the face similar to ones in the Github README examples."
    ),
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()
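
A quick way to smoke-test the demo outside the browser is to call the same infer() function the Gradio interface wraps. This is a minimal sketch, assuming it runs from the repo root so the hard-coded checkpoint path resolves, and using a hypothetical test image face.jpg; importing app executes the module-level code, so the generators and checkpoint load once at import time:

from PIL import Image

from app import infer  # module import loads G, F and the checkpoint

img = Image.open("face.jpg").convert("RGB")  # hypothetical close-up portrait
aged = infer(img, "young2old")               # "old2young" would rejuvenate instead
aged.save("face_aged.jpg")

The browser UI itself starts with python app.py, which falls through to demo.launch().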

requirements.txt

Lines changed: 2 additions & 1 deletion
@@ -10,4 +10,5 @@ torchmetrics[image]
 wandb
 numpy
 python-dotenv
-boto3
+boto3
+gradio

setup.py

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@
 "python-dotenv",
 "numpy",
 "wandb",
-# "gradio",
+"gradio",
 "ipykernel",
 "matplotlib",
 "torch",

src/aging_gan/inference.py

Lines changed: 22 additions & 13 deletions
@@ -36,28 +36,37 @@ def parse_args() -> argparse.Namespace:
         default="young2old",
         help="'young2old' uses generator G, 'old2young' uses generator F",
     )
+    p.add_argument(
+        "--train_img_size",
+        type=int,
+        default=256,
+        help="The same img_size you used for training and evaluating.",
+    )
     return p.parse_args()


-# image helpers
-preprocess = T.Compose(
-    [
-        T.Resize(256, interpolation=Image.BICUBIC),
-        T.CenterCrop(256),
-        T.ToTensor(),
-        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
-    ]
-)
-
-postprocess = T.Compose([T.Normalize(mean=[-1, -1, -1], std=[2, 2, 2]), T.ToPILImage()])
-
-
 @torch.inference_mode()
 def main() -> None:
     """Load a checkpoint and generate an aged face from ``--input``."""
     cfg = parse_args()
     device = get_device()

+    # image helpers
+    preprocess = T.Compose(
+        [
+            T.Resize(
+                (cfg.train_img_size + 50, cfg.train_img_size + 50), antialias=True
+            ),
+            T.CenterCrop(cfg.train_img_size),
+            T.ToTensor(),
+            T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+        ]
+    )
+
+    postprocess = T.Compose(
+        [T.Normalize(mean=[-1, -1, -1], std=[2, 2, 2]), T.ToPILImage()]
+    )
+
     # Load generators and checkpoint
     G, F, *_ = initialize_models()  # returns G, F, DX, DY
     ckpt = torch.load(cfg.ckpt, map_location=device)  # same keys as used in train.py

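For reference, a hedged usage sketch of the updated inference script. Only --train_img_size appears verbatim in this diff; the entry point and the --input, --ckpt, and --direction flag names are assumptions inferred from the docstring, cfg.ckpt, and the "young2old" default visible above:

    python src/aging_gan/inference.py --input face.jpg --ckpt outputs/checkpoints/epoch_0030.pth --direction young2old --train_img_size 256

--train_img_size only needs to change if the checkpoint was trained at a different resolution: the transform above resizes to (train_img_size + 50) on each side, then center-crops back to train_img_size before normalizing to [-1, 1].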