11from fastapi import Request
22import torch
33import timm
4- import json
5- from PIL import Image
64import torchvision .transforms as T
7- import base64
5+ from PIL import Image
6+ import json
87import io
98import logging
109
4948 T .CenterCrop (input_size ),
5049])
5150
52- async def classify_image (req : Request , debug : bool ):
51+ async def classify_image (req : Request ):
5352 img_bytes = await req .body ()
5453 img = Image .open (io .BytesIO (img_bytes )).convert ("RGB" )
5554
5655 # Image preparation
57- resized_img = debug_transform (img )
5856 input_tensor = transform (img ).unsqueeze (0 )
5957
6058 with torch .no_grad ():
@@ -74,29 +72,4 @@ async def classify_image(req: Request, debug: bool):
7472 if not isinstance (results , list ):
7573 results = [results ]
7674
77- if not debug :
78- return {"predictions" : results }, {}
79-
80- # Debug details
81- topk_debug = torch .topk (probs , k = 50 )
82- top_50 = [
83- {
84- "label" : idx_to_label .get (str (idx .item ()), f"Unknown ({ idx .item ()} )" ),
85- "score" : float (score )
86- }
87- for idx , score in zip (topk_debug .indices [0 ], topk_debug .values [0 ])
88- ]
89-
90- buffered = io .BytesIO ()
91- resized_img .save (buffered , format = "JPEG" )
92- img_base64 = base64 .b64encode (buffered .getvalue ()).decode ("utf-8" )
93-
94- return {
95- "predictions" : results ,
96- "top_50" : top_50 ,
97- "logits_shape" : list (logits .shape ),
98- "max_prob" : float (probs .max ()),
99- "min_prob" : float (probs .min ()),
100- "resized_image" : img_base64 ,
101- "confirm_order" : "offline_weights_loaded" ,
102- }
75+ return {"predictions" : results }
0 commit comments