-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path app_v0.py
More file actions
122 lines (98 loc) · 4.08 KB
/
app_v0.py
File metadata and controls
122 lines (98 loc) · 4.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import torch
from PIL import Image
from fastapi import (FastAPI, UploadFile, File, Request, HTTPException, Form)
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
from src.cnn import CNN_TUMOR, im2gradCAM
from src.utils.utils import preprocess_image, CLA_label
import uuid
import os
# Path to the trained weights, resolved relative to the current working directory.
best_model_wts = 'models/saved_model/model_cnn_best.pt'

# Hyper-parameters handed to CNN_TUMOR. They presumably have to match the
# architecture the checkpoint was trained with, since load_state_dict below
# loads strictly — TODO confirm against the training script.
params_model = dict(
    shape_in=(3, 256, 256),
    initial_filters=8,
    num_fc1=100,
    dropout_rate=0.25,
    num_classes=4,
)

# 1. Build the network, then load the trained weights mapped onto the CPU.
model = CNN_TUMOR(params_model)
model.load_state_dict(
    torch.load(
        best_model_wts,
        map_location=torch.device('cpu'),
        weights_only=True,  # only tensors/primitive containers, no arbitrary pickles
    )
)
# Switch to evaluation mode (affects train-time layers such as dropout).
model.eval()
# 2. Predict function
def predict(image):
    """Classify one image with the module-level ``model``.

    Args:
        image: input accepted by ``preprocess_image`` (the callers in this
            file pass an RGB PIL image).

    Returns:
        ``(class_index, probabilities)``: the index of the most likely class
        as a Python int, and the softmax tensor computed over dim 1 of the
        model output.
    """
    batch = preprocess_image(image)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = model(batch)
    probs = torch.nn.functional.softmax(logits, dim=1)
    # .item() below assumes preprocess_image yields a batch of exactly one
    # image — TODO confirm.
    best = torch.max(probs, dim=1).indices
    return best.item(), probs
app = FastAPI()

# Uploaded and Grad-CAM images are written under UPLOAD_DIR and linked from
# the templates, so UPLOAD_DIR must sit inside the directory that is served
# as static content.
STATIC_DIR = "frontend/static"
UPLOAD_DIR = "frontend/static/uploads"
# Also creates STATIC_DIR, which StaticFiles requires to exist at startup.
os.makedirs(UPLOAD_DIR, exist_ok=True)

# Serve static files (CSS and the generated images). Starlette rejects mount
# paths that do not start with "/". The handlers emit relative URLs of the
# form "frontend/static/uploads/<name>". Every route in this file is
# top-level, so the browser resolves those URLs to "/frontend/static/...",
# which must map onto the directory the files are actually saved in.
app.mount("/frontend/static", StaticFiles(directory=STATIC_DIR), name="static")

# Initialize template engine.
# NOTE(review): static files live under "frontend/" but templates are read
# from "./templates" — confirm this is the intended layout.
templates = Jinja2Templates(directory="templates")
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Render the landing page (``index.html``)."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
@app.post("/result", response_class=HTMLResponse)
async def result(
    request: Request,
    file: UploadFile = File(None),
    original_image_url: str = Form(None),
    predicted_class = Form(None),
    probabilities = Form(None)):
    """Classify an uploaded image, or re-render a previously computed result.

    The handler has two modes:

    * A file is uploaded: it is decoded, converted to RGB and saved as PNG
      under ``UPLOAD_DIR``. It is then classified with ``predict``, and the
      class label plus the top-class probability (3 decimals) are rendered.
    * No file, but ``original_image_url`` is posted: the posted
      ``predicted_class`` / ``probabilities`` are rendered unchanged
      (presumably echoed back by a form on a previous result page).

    Returns ``result.html``. Returns a 422 HTML error if no image is provided
    or the upload cannot be decoded as an image.
    """
    # A browser submits an empty part with filename "" when the file input is
    # left blank; treat that the same as "no upload".
    has_upload = file is not None and bool(file.filename)
    if has_upload:
        try:
            image = Image.open(file.file).convert("RGB")
        except (OSError, Image.DecompressionBombError):
            # OSError covers UnidentifiedImageError and truncated image data.
            return HTMLResponse("Error: Uploaded file is not a valid image.", status_code=422)
        # Save the decoded image under a random, collision-free name.
        img_filename = f"{uuid.uuid4().hex}.png"
        img_path = os.path.join(UPLOAD_DIR, img_filename)
        image.save(img_path)
        original_image_url = f"frontend/static/uploads/{img_filename}"
        # NOTE(review): predict() is CPU-bound and runs synchronously inside
        # this async handler, blocking the event loop during inference;
        # consider a plain `def` endpoint or starlette's run_in_threadpool.
        class_index, probs = predict(image)
        # Probability of the predicted (arg-max) class, for display.
        probabilities = f"{probs.tolist()[0][class_index]:0.3f}"
        predicted_class = CLA_label[class_index]
    elif not original_image_url:
        # Neither an upload nor a previously stored image was supplied.
        return HTMLResponse("Error: No image provided.", status_code=422)
    # Render the result page with the (possibly newly stored) image URL.
    return templates.TemplateResponse("result.html", {
        "request": request,
        "predicted_class": predicted_class,
        "probabilities": probabilities,
        "original_image_url": original_image_url,
    })
@app.post("/apply-gradcam", response_class=HTMLResponse)
async def apply_operation(
    request: Request,
    image_url: str = Form(...),
    predicted_class = Form(...),
    probabilities = Form(...)):
    """Generate a Grad-CAM visualization for a previously uploaded image.

    ``image_url`` is reduced to its base name and looked up in ``UPLOAD_DIR``,
    so directory components supplied by the client are ignored. The Grad-CAM
    output is saved as a new PNG in ``UPLOAD_DIR`` and rendered with
    ``apply_gradcam.html`` next to the original image. ``predicted_class`` and
    ``probabilities`` are passed through to the template unchanged.

    Returns a 404 HTML error if the referenced image does not exist.
    """
    img_path = os.path.join(UPLOAD_DIR, os.path.basename(image_url))
    # An empty/".." name or an unknown file would otherwise surface as an
    # unhandled IsADirectoryError / FileNotFoundError (HTTP 500).
    if not os.path.isfile(img_path):
        return HTMLResponse("Error: Image not found.", status_code=404)
    # Apply Grad-CAM; the context manager closes the image file afterwards.
    with Image.open(img_path) as image:
        new_image = Image.fromarray(im2gradCAM(model, image))
    # Save the Grad-CAM image under a random, collision-free name.
    new_img_filename = f"{uuid.uuid4().hex}.png"
    new_img_path = os.path.join(UPLOAD_DIR, new_img_filename)
    new_image.save(new_img_path)
    # Render (not redirect to) the Grad-CAM result page.
    return templates.TemplateResponse("apply_gradcam.html", {
        "request": request,
        "original_image_url": image_url,
        "new_image_url": f"frontend/static/uploads/{new_img_filename}",
        "predicted_class": predicted_class,
        "probabilities": probabilities,
    })
# Development entry point: when this file is executed directly, serve the
# app with Uvicorn on the loopback interface only (127.0.0.1:8000).
if __name__ == "__main__":
    from uvicorn import run as serve_app

    serve_app(app, port=8000, host="127.0.0.1")