Skip to content

Commit 6e90199

Browse files
authored
Merge pull request #881 from bmahabirbu/object-fix
fix: add new file format support (PNG)
2 parents 63a83a7 + 23cd8f5 commit 6e90199

File tree

2 files changed

+19
-10
lines changed

2 files changed

+19
-10
lines changed

model_servers/object_detection_python/src/object_detection_server.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import os
99
import io
1010
import shutil
11+
from typing import Optional
1112

1213

1314
app = FastAPI()
@@ -28,7 +29,7 @@
2829
model = AutoModelForObjectDetection.from_pretrained(model, revision=revision)
2930

3031
class Item(BaseModel):
31-
image: bytes
32+
image: bytes
3233

3334
@app.get("/health")
3435
def tests_alive():
@@ -39,7 +40,7 @@ def detection(item: Item):
3940
b64_image = item.image
4041
b64_image = base64.b64decode(b64_image)
4142
bytes_io = io.BytesIO(b64_image)
42-
image = Image.open(bytes_io)
43+
image = Image.open(bytes_io).convert("RGB")
4344
inputs = processor(images=image, return_tensors="pt")
4445
outputs = model(**inputs)
4546
target_sizes = torch.tensor([image.size[::-1]])
@@ -54,8 +55,9 @@ def detection(item: Item):
5455
label_confidence = f"Detected {model.config.id2label[label.item()]} with confidence {round(score.item(), 3)}"
5556
scores.append(label_confidence)
5657

57-
bytes_io = io.BytesIO()
58-
image.save(bytes_io, "JPEG")
58+
bytes_io = io.BytesIO()
59+
# Convert image format to PNG
60+
image.save(bytes_io, format="PNG")
5961
img_bytes = bytes_io.getvalue()
6062
img_bytes = base64.b64encode(img_bytes).decode('utf-8')
6163
return {'image': img_bytes, "boxes": scores}

recipes/computer_vision/object_detection/app/object_detection_client.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,22 @@
2323
scale_factor = 0.20
2424
img = img.resize((int(img.width * scale_factor) ,
2525
int(img.height * scale_factor)))
26-
window.image(img, use_column_width=True)
27-
# convert PIL image into bytes for post request
26+
window.image(img, use_column_width=True)
27+
28+
# Convert image to RGB
29+
img = img.convert("RGB")
30+
# Encode image
2831
bytes_io = io.BytesIO()
29-
if img.mode in ("RGBA", "P"):
30-
img = img.convert("RGB")
31-
img.save(bytes_io, "JPEG")
32+
# Convert image format to PNG
33+
img.save(bytes_io, format="PNG")
3234
img_bytes = bytes_io.getvalue()
3335
b64_image = base64.b64encode(img_bytes).decode('utf-8')
34-
data = {'image': b64_image}
36+
37+
# Prepare payload
38+
data = {
39+
'image': b64_image,
40+
}
41+
3542
response = requests.post(f'{endpoint}/detection', headers=headers,json=data, verify=False)
3643
# parse response and display outputs
3744
response_json = response.json()

0 commit comments

Comments
 (0)