Skip to content

Commit 23ced58

Browse files
committed
fix(Aesthetic Scorer): Remove debug and nesting response
1 parent 2b4d1cb commit 23ced58

File tree

7 files changed

+28
-89
lines changed

7 files changed

+28
-89
lines changed

app/api/admin/classify/route.ts

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,6 @@ export async function POST(req: Request) {
1313
try {
1414
const utils = utilsFactory()
1515

16-
const url = new URL(req.url)
17-
const debug = url.searchParams.get('debug') === 'true'
18-
1916
const { path: relativePath } = await req.json()
2017

2118
if (!relativePath) {
@@ -25,7 +22,7 @@ export async function POST(req: Request) {
2522
const fullPath = utils.safePublicPath(relativePath)
2623
const buffer = await fs.readFile(fullPath)
2724

28-
const classifyUrl = `http://localhost:${config.pythonPort}/classify?debug=${debug}`
25+
const classifyUrl = `http://localhost:${config.pythonPort}/classify`
2926

3027
const res = await fetch(classifyUrl, {
3128
method: 'POST',

app/api/admin/scores/route.ts

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,13 @@
1-
import { NextRequest, NextResponse } from 'next/server'
1+
import { NextResponse } from 'next/server'
22
import fs from 'node:fs/promises'
33

44
import utilsFactory from '../../../../src/lib/utils'
55
import config from '../../../../src/models/config'
66

7-
async function POST(req: Request) {
7+
export async function POST(req: Request) {
88
try {
99
const utils = utilsFactory()
1010

11-
const url = new URL(req.url)
12-
const debug = url.searchParams.get('debug') === 'true'
13-
1411
const { path: relativePath } = await req.json()
1512

1613
if (!relativePath) {
@@ -20,7 +17,7 @@ async function POST(req: Request) {
2017
const fullPath = utils.safePublicPath(relativePath)
2118
const buffer = await fs.readFile(fullPath)
2219

23-
const classifyUrl = `http://localhost:${config.pythonPort}/scores?debug=${debug}`
20+
const classifyUrl = `http://localhost:${config.pythonPort}/scores`
2421

2522
const res = await fetch(classifyUrl, {
2623
method: 'POST',
@@ -42,18 +39,3 @@ async function POST(req: Request) {
4239
return NextResponse.json({ error: err.message || 'Unexpected error' }, { status: 500 })
4340
}
4441
}
45-
46-
// Catch-all for unsupported methods
47-
function notSupported(req: NextRequest) {
48-
return NextResponse.json(`Method ${req.method} Not Allowed`, { status: 405 })
49-
}
50-
51-
export {
52-
notSupported as GET,
53-
POST,
54-
notSupported as PUT,
55-
notSupported as DELETE,
56-
notSupported as PATCH,
57-
notSupported as OPTIONS,
58-
notSupported as HEAD,
59-
}

apps/api/aesthetic.py

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
from fastapi import Request
12
import torch
23
import torch.nn as nn
3-
from PIL import Image
44
import torchvision.transforms as T
55
import open_clip
6+
from PIL import Image
67
import logging
78
from collections import OrderedDict
9+
import io
810

911
# Set up logging once
1012
logging.basicConfig(level=logging.DEBUG)
@@ -69,21 +71,14 @@ def load_clip_model() -> tuple[torch.nn.Module, callable]:
6971
_clip_model, preprocess = load_clip_model()
7072
regression_head = load_aesthetic_head(HEAD_PATH)
7173

72-
def score_aesthetic(pil_image: Image.Image) -> float:
73-
model, _, preprocess = open_clip.create_model_and_transforms(
74-
'ViT-L-14',
75-
pretrained=None
76-
)
77-
model.eval()
78-
79-
logger.info("Load your checkpoint...")
80-
logger.info("Load your regression head...")
81-
82-
with torch.no_grad():
83-
image_tensor = preprocess(pil_image).unsqueeze(0)
84-
image_features = model.encode_image(image_tensor)
85-
image_features /= image_features.norm(dim=-1, keepdim=True)
86-
score = regression_head(image_features).item()
74+
async def score_aesthetic(req: Request) -> float:
75+
img_bytes = await req.body()
76+
img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
8777

88-
return float(score)
78+
with torch.no_grad():
79+
image_tensor = preprocess(img).unsqueeze(0)
80+
image_features = _clip_model.encode_image(image_tensor)
81+
image_features /= image_features.norm(dim=-1, keepdim=True)
82+
score = regression_head(image_features).item()
8983

84+
return float(score)

apps/api/classify.py

Lines changed: 4 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
from fastapi import Request
22
import torch
33
import timm
4-
import json
5-
from PIL import Image
64
import torchvision.transforms as T
7-
import base64
5+
from PIL import Image
6+
import json
87
import io
98
import logging
109

@@ -49,12 +48,11 @@
4948
T.CenterCrop(input_size),
5049
])
5150

52-
async def classify_image(req: Request, debug: bool):
51+
async def classify_image(req: Request):
5352
img_bytes = await req.body()
5453
img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
5554

5655
# Image preparation
57-
resized_img = debug_transform(img)
5856
input_tensor = transform(img).unsqueeze(0)
5957

6058
with torch.no_grad():
@@ -74,29 +72,4 @@ async def classify_image(req: Request, debug: bool):
7472
if not isinstance(results, list):
7573
results = [results]
7674

77-
if not debug:
78-
return {"predictions": results}, {}
79-
80-
# Debug details
81-
topk_debug = torch.topk(probs, k=50)
82-
top_50 = [
83-
{
84-
"label": idx_to_label.get(str(idx.item()), f"Unknown ({idx.item()})"),
85-
"score": float(score)
86-
}
87-
for idx, score in zip(topk_debug.indices[0], topk_debug.values[0])
88-
]
89-
90-
buffered = io.BytesIO()
91-
resized_img.save(buffered, format="JPEG")
92-
img_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
93-
94-
return {
95-
"predictions": results,
96-
"top_50": top_50,
97-
"logits_shape": list(logits.shape),
98-
"max_prob": float(probs.max()),
99-
"min_prob": float(probs.min()),
100-
"resized_image": img_base64,
101-
"confirm_order": "offline_weights_loaded",
102-
}
75+
return {"predictions": results}

apps/api/main.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
1-
from fastapi import FastAPI, Request, Query
1+
from fastapi import FastAPI, Request
22
from fastapi.responses import JSONResponse
33
import logging
44
import sys
55
import traceback
66
from aesthetic import score_aesthetic
77
from classify import classify_image
8-
from PIL import Image
9-
import io
108

119
# Setup logging once
1210
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
@@ -30,22 +28,17 @@ def error_response(e: Exception):
3028
)
3129

3230
@main_py_app.post("/classify")
33-
async def classify_endpoint(req: Request, debug: bool = Query(False)):
31+
async def classify_endpoint(req: Request):
3432
try:
35-
results, debug_data = await classify_image(req, debug)
36-
return {"predictions": results} if not debug else {
37-
"predictions": results,
38-
**debug_data
39-
}
33+
results = await classify_image(req)
34+
return results
4035
except Exception as e:
4136
return error_response(e)
4237

4338
@main_py_app.post("/scores")
4439
async def score_endpoint(req: Request):
4540
try:
46-
img_bytes = await req.body()
47-
img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
48-
score = score_aesthetic(img)
41+
score = await score_aesthetic(req)
4942
return {"aesthetic_score": round(score, 3)}
5043
except Exception as e:
5144
return error_response(e)

apps/api/tests/test_aesthetic.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
1-
import pytest
21
from PIL import Image
32
from aesthetic import score_aesthetic
43

5-
def test_score_aesthetic_on_sample_image():
4+
async def test_score_aesthetic_on_sample_image():
65
img = Image.new("RGB", (224, 224), color="blue") # deterministic image
7-
score = score_aesthetic(img)
6+
score = await score_aesthetic(img)
87
assert isinstance(score, float)
98
assert -1.0 <= score <= 1.0

apps/api/tests/test_routes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,6 @@ def test_classify_route():
3737
print("DEBUG BODY:", body)
3838

3939
assert "predictions" in body
40-
predictions = body["predictions"]["predictions"]
40+
predictions = body["predictions"]
4141
assert isinstance(predictions, list), f"Expected list, got: {type(predictions)}"
4242
assert all("label" in pred and "score" in pred for pred in predictions)

0 commit comments

Comments
 (0)