
Commit e27e6b1

Merge pull request #1 from codinglabsong/codex/update-gradio_app.py-to-add-example-images
Remove binary sample images and clarify docs
2 parents: a7c75c8 + 42e8201

3 files changed: +52, -43 lines


.gitignore

Lines changed: 3 additions & 0 deletions
```diff
@@ -20,3 +20,6 @@ data/sample/train/
 # local‐mode outputs
 sagemaker-*
 output/
+
+# example images
+examples/
```

README.md

Lines changed: 11 additions & 0 deletions
````diff
@@ -109,6 +109,17 @@ This project includes an interactive Gradio app for making predictions with the
 python gradio_app.py
 ```
 - The app will start locally and print a link (e.g., `http://127.0.0.1:7860`) to access the web UI in your browser.
+## Deploying on Hugging Face Spaces
+1. Create a new Gradio Space on [Hugging Face](https://huggingface.co/spaces).
+2. Upload the following files from this repo:
+   - `gradio_app.py`
+   - `requirements.txt`
+   - `class_names.txt`
+   - `config/prod.yaml`
+   - `output/model.pth`
+   - *(optional)* an `examples/` folder with sample images for the Gradio UI
+3. Commit and push to the Space. Hugging Face will build and launch the app.
+
 
 ## Requirements
 - See `requirements.txt`
````
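The manual upload in step 2 can also be scripted. Below is a minimal sketch using the `huggingface_hub` client; it is not part of this commit, the Space id `your-username/food101-classifier` is a placeholder, and it assumes `huggingface_hub` is installed and you are authenticated (e.g. via `huggingface-cli login`). One caveat: Gradio Spaces run `app.py` by default, so with this file list you may need to set `app_file: gradio_app.py` in the Space's README metadata or rename the file to `app.py`.

```python
# Hypothetical upload script; "your-username/food101-classifier" is a
# placeholder Space id -- replace it with your own.
from huggingface_hub import HfApi

api = HfApi()
repo_id = "your-username/food101-classifier"

# Create the Gradio Space if it does not exist yet
api.create_repo(repo_id=repo_id, repo_type="space", space_sdk="gradio", exist_ok=True)

# Push only the files the README lists
api.upload_folder(
    folder_path=".",
    repo_id=repo_id,
    repo_type="space",
    allow_patterns=[
        "gradio_app.py",
        "requirements.txt",
        "class_names.txt",
        "config/prod.yaml",
        "output/model.pth",
        "examples/*",
    ],
)
```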

gradio_app.py

Lines changed: 38 additions & 43 deletions
```diff
@@ -1,72 +1,67 @@
-from collections import OrderedDict
-import torch, yaml
+import os
+import yaml
+import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torchvision import models, transforms
 import gradio as gr
 
+# Load config and class names
 cfg = yaml.safe_load(open("config/prod.yaml"))
+with open("class_names.txt") as f:
+    class_names = [line.strip() for line in f]
 
-# 1. Recreate model class
-def build_model(num_classes):
-    """
-    Builds an EfficientNet-B2 model with a custom classification head.
-
-    Args:
-        num_classes (int): Number of output classes for the classification head.
+# Build and load model
 
-    Returns:
-        nn.Module: The modified EfficientNet-B2 model.
-    """
+def build_model(num_classes: int) -> nn.Module:
     model = models.efficientnet_b2(weights=models.EfficientNet_B2_Weights.IMAGENET1K_V1)
     in_features = model.classifier[1].in_features
     model.classifier[1] = nn.Linear(in_features, num_classes)
     return model
 
-# 2. Load class names
-# Load class names from file
-with open("class_names.txt") as f:
-    class_names = [line.strip() for line in f]
 
-# 3. Build and load the model
-num_classes = len(class_names)
-model = build_model(num_classes)
-
-# If you see _orig_mod keys, strip the prefix! (Due to possibility of saving compiled version of model during training)
-ckpt = torch.load("output/model.pth", map_location='cpu')
-new_state_dict = OrderedDict()
-for k, v in ckpt.items():
-    if k.startswith('_orig_mod.'):
-        new_state_dict[k[len('_orig_mod.'):]] = v
-    else:
-        new_state_dict[k] = v
+def load_model(path: str, num_classes: int) -> nn.Module:
+    model = build_model(num_classes)
+    state = torch.load(path, map_location="cpu")
+    state = {k.replace("_orig_mod.", ""): v for k, v in state.items()}
+    model.load_state_dict(state)
+    model.eval()
+    return model
 
-model.load_state_dict(new_state_dict)
-model.eval()
+model = load_model("output/model.pth", len(class_names))
 
-# 4. Preprocessing: same as test transforms in train.py
+# Preprocessing (must match training)
 preprocess = transforms.Compose([
     transforms.Resize(256),
     transforms.CenterCrop(cfg["estimator"]["hyperparameters"]["img-size"]),
     transforms.ToTensor(),
-    transforms.Normalize([0.485,0.456,0.406],
-                         [0.229,0.224,0.225])
+    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
 ])
 
-# 5. Inference function
+
 def predict(image):
     image = preprocess(image).unsqueeze(0)
     with torch.no_grad():
-        outputs = model(image)  # shape: [1, 101]
-    probs = F.softmax(outputs, dim=1).squeeze().cpu().numpy()  # shape: [101]
-    sorted_indices = probs.argsort()[::-1]  # descending order
-    result = {class_names[i]: float(probs[i]) for i in sorted_indices}
-    return result
+        outputs = model(image)
+    probs = F.softmax(outputs, dim=1)[0]
+    return {class_names[i]: float(probs[i]) for i in probs.argsort(descending=True)}
+
+# Example images for the UI
+example_dir = "examples"
+if os.path.isdir(example_dir):
+    examples = [
+        [os.path.join(example_dir, f)]
+        for f in os.listdir(example_dir)
+        if f.lower().endswith((".png", ".jpg", ".jpeg"))
+    ]
+else:
+    examples = None
 
-# 6. Gradio app
+# Launch Gradio app
 gr.Interface(
     fn=predict,
     inputs=gr.Image(type="pil"),
-    outputs=gr.Label(num_top_classes=101, label="Class Probabilities"),
-    title="Food101 Classifier"
-).launch()
+    outputs=gr.Label(num_top_classes=5, label="Top Classes"),
+    title="Food101 Classifier",
+    examples=examples,
+).launch()
```
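The refactored `load_model()` keeps the old loop's key-stripping behavior in a single dict comprehension. For readers wondering where the `_orig_mod.` prefix comes from: `torch.compile` wraps a module, and the wrapper's `state_dict()` keys carry that prefix. A standalone sketch (not from this repo) demonstrating the effect:

```python
# Why load_model() strips "_orig_mod.": torch.compile wraps the module,
# and the wrapper's state_dict keys gain that prefix.
import torch
import torch.nn as nn

net = nn.Linear(4, 2)
compiled = torch.compile(net)

print(list(net.state_dict()))       # ['weight', 'bias']
print(list(compiled.state_dict()))  # ['_orig_mod.weight', '_orig_mod.bias']

# A checkpoint saved from the compiled model will not load into a plain
# nn.Module until the prefix is removed:
state = {k.replace("_orig_mod.", ""): v for k, v in compiled.state_dict().items()}
net.load_state_dict(state)  # loads cleanly
```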

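Before pushing the Space, the updated `predict()` can be smoke-tested locally. A hypothetical check, assuming these lines run where `predict` is defined (e.g. temporarily added above the `gr.Interface(...).launch()` call) and that a sample image such as `examples/pizza.jpg` exists:

```python
# Hypothetical smoke test; "examples/pizza.jpg" is a placeholder path.
from PIL import Image

img = Image.open("examples/pizza.jpg").convert("RGB")
scores = predict(img)  # dict is built in descending-probability order

# Show the top five classes, matching num_top_classes=5 in the UI
for name, p in list(scores.items())[:5]:
    print(f"{name}: {p:.3f}")
```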