Skip to content

Commit 9c08788

Browse files
authored
Merge pull request #11 from lucasalvaa/fast-API
Fast API and Dockerfile
2 parents 6982c80 + 991df1d commit 9c08788

File tree

3 files changed

+160
-0
lines changed

3 files changed

+160
-0
lines changed

Dockerfile

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Python light version
2+
FROM python:3.12-slim
3+
4+
ENV PYTHONDONTWRITEBYTECODE=1
5+
ENV PYTHONUNBUFFERED=1
6+
7+
WORKDIR /app
8+
9+
RUN apt-get update && apt-get upgrade -y && apt-get install -y --no-install-recommends \
10+
libopenjp2-7 \
11+
&& rm -rf /var/lib/apt/lists/*
12+
13+
RUN pip install --no-cache-dir --upgrade pip setuptools wheel
14+
15+
# Python dependencies
16+
RUN pip install --no-cache-dir \
17+
fastapi \
18+
uvicorn \
19+
python-multipart \
20+
"pillow>=11.0.0"
21+
22+
# Pytorch
23+
RUN pip install --no-cache-dir \
24+
torch torchvision \
25+
--index-url https://download.pytorch.org/whl/cpu
26+
27+
RUN useradd -m appuser
28+
USER appuser
29+
30+
# Copy script and model weights
31+
COPY src/api.py .
32+
COPY --chown=appuser:appuser pipeline3/effnet_s/finetuned/model.pth ./weights/model.pth
33+
34+
EXPOSE 8080
35+
CMD ["python", "api.py"]

ruff.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ ignore = [
2020
"D100", # Spesso i docstring nei file __init__.py sono ridondanti
2121
"ANN101", # Non serve annotare 'self'
2222
"ANN102", # Non serve annotare 'cls'
23+
"N812",
24+
"B008"
2325
]
2426

2527
[lint.mccabe]

src/api.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
import io
2+
import time
3+
from typing import Dict, Tuple
4+
5+
import torch
6+
import torch.nn as nn
7+
import torch.nn.functional as F
8+
import uvicorn
9+
from fastapi import FastAPI, File, UploadFile
10+
from PIL import Image
11+
from pydantic import BaseModel
12+
from torchvision import models, transforms
13+
14+
MODEL_PATH = "weights/model.pth" # Container path
15+
MODEL_VERSION = "efficientnet_v2_s"
16+
17+
CLASSES = [
18+
"demodicosis",
19+
"dermatitis",
20+
"fungal_infections",
21+
"healthy",
22+
"hypersensitivity",
23+
"ringworm",
24+
]
25+
26+
app = FastAPI(title="Dog Skin Disease Classifier")
27+
28+
29+
def load_model() -> Tuple[nn.Module, torch.device]:
30+
"""Load the EfficientNetV2_S model and its fine-tuned weights.
31+
32+
The architecture is modified by replacing the final classifier
33+
to adapt it to the problem-specific number of classes.
34+
35+
Returns:
36+
Tuple[nn.Module, torch.device]: The loaded model and the device (CPU).
37+
38+
"""
39+
# Initialize the model
40+
model = getattr(models, MODEL_VERSION)(weights=None)
41+
42+
# In EfficientNetV2, classifier is accessible through model.classifier[1]
43+
# Structure: [0] Dropout, [1] Linear
44+
n_inputs = model.classifier[1].in_features
45+
model.classifier[1] = nn.Linear(n_inputs, len(CLASSES))
46+
47+
# state_dict loading
48+
device = torch.device("cpu")
49+
state_dict = torch.load(MODEL_PATH, map_location=device, weights_only=True)
50+
51+
model.load_state_dict(state_dict)
52+
model.eval()
53+
return model, device
54+
55+
56+
# Model is loaded when the application starts
57+
model, device = load_model()
58+
59+
preprocess = transforms.Compose(
60+
[
61+
transforms.Resize((224, 224)),
62+
transforms.ToTensor(),
63+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
64+
]
65+
)
66+
67+
68+
class PredictionResponse(BaseModel):
69+
"""Response pattern for model prediction.
70+
71+
Attributes:
72+
label_name (str): Name of the predicted class.
73+
confidence_score (float): Probability associated with the predicted class.
74+
inference_time_ms (float): Time taken for inference in milliseconds.
75+
model_version_id (str): Model version identifier.
76+
77+
"""
78+
79+
label_name: str
80+
confidence_score: float
81+
inference_time_ms: float
82+
model_version_id: str
83+
84+
85+
@app.post("/predict", response_model=PredictionResponse)
86+
async def predict(file: UploadFile = File(...)) -> Dict:
87+
"""Upon receiving an image, it performs preprocessing and returns the prediction.
88+
89+
Args:
90+
file (UploadFile): Image file uploaded via POST request.
91+
92+
Returns:
93+
Dict: Classification result with score and execution time.
94+
95+
"""
96+
start_time = time.perf_counter()
97+
98+
# Read the uploaded image
99+
content = await file.read()
100+
image = Image.open(io.BytesIO(content)).convert("RGB")
101+
102+
# Prepare the image to be fed as input to the model
103+
input_tensor = preprocess(image).unsqueeze(0).to(device)
104+
105+
# Inference without gradient calculation
106+
with torch.no_grad():
107+
outputs = model(input_tensor)
108+
probabilities = F.softmax(outputs[0], dim=0)
109+
110+
# Class and confidence score extraction
111+
conf, idx = torch.max(probabilities, 0)
112+
label = CLASSES[idx.item()]
113+
114+
return {
115+
"label_name": label,
116+
"confidence_score": round(float(conf), 4),
117+
"inference_time_ms": round((time.perf_counter() - start_time) * 1000, 2),
118+
"model_version_id": MODEL_VERSION,
119+
}
120+
121+
122+
if __name__ == "__main__":
123+
uvicorn.run(app, host="0.0.0.0", port=8080)

0 commit comments

Comments
 (0)