Skip to content

Commit 2f2a7f2

Browse files
committed
Add test coverage for preprocessing modes and implement resize_mode support
- Added new tests in `tests/test_preprocessing.py` to cover stretch, crop, and pad preprocessing modes. - Refactored `serve.py` to include a reusable `preprocess_image` function that handles multiple resize modes (`pad`, `crop`, and `stretch`). - Updated `read_image_from_conn` and related methods to support dynamic resize modes based on an optional `resize_mode` parameter. - Expanded README with detailed explanations of resize modes and updated request examples.
1 parent 96ddff8 commit 2f2a7f2

File tree

3 files changed

+186
-5
lines changed

3 files changed

+186
-5
lines changed

README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,17 @@ Steps:
9191
{
9292
"type": "predict",
9393
"len": 12345
}
94+
{
95+
"type": "predict",
96+
"len": 12345,
97+
"resize_mode": "pad"
9498
}
9599
```
100+
101+
**Resize Modes (`resize_mode`):**
102+
* `pad` (Default): Pads the image with black borders to preserve aspect ratio (adds bars), then resizes to 256x256.
103+
* `crop`: Center-crops a square from the image, then resizes to 256x256.
104+
* `stretch`: Stretches the image to fit 256x256 (may distort aspect ratio).
96105
3. **Send Image**:
97106
* **Option A (Recommended):** Send a standard image file (PNG, BMP, JPG). The server uses `cv2.imdecode` to parse it automatically.
98107
* **Option B (Fallback):** Send **196,608 bytes** of raw RGB pixel data (256x256). If `len` matches exactly, it is treated as raw buffer.

scripts/serve.py

Lines changed: 59 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,60 @@
1313
# Lock to ensure thread safety for the stateful InferenceSession
1414
session_lock = threading.Lock()
1515

16+
def preprocess_image(img, mode="pad"):
    """
    Resize an image to 256x256 using the requested strategy.

    Modes:
    - "stretch": resize directly, ignoring the aspect ratio.
    - "crop": center-crop the largest square, then resize.
    - "pad": pad the shorter side with black up to a square, then resize.
    Any other value is handled exactly like "pad".

    Args:
        img: Image array exposing ``shape`` as (height, width, ...) and
            supporting 2-D slicing, or None.
        mode: Name of the resize strategy.

    Returns:
        The 256x256 image. When no crop/pad/resize is needed the input
        object itself is returned. Returns None if ``img`` is None or has
        a zero-sized height or width.
    """
    target_size = (256, 256)
    if img is None:
        return None

    h, w = img.shape[:2]
    # An empty image cannot be resized (cv2.resize rejects empty input), so
    # report it the same way as a missing image instead of raising.
    if h == 0 or w == 0:
        return None

    if mode == "stretch":
        # No squaring step: the final resize does the stretching.
        pass
    elif mode == "crop":
        if h != w:
            side = min(h, w)
            # side <= h and side <= w, so both offsets are non-negative.
            top = h // 2 - side // 2
            left = w // 2 - side // 2
            img = img[top:top + side, left:left + side]
    elif h != w:
        # "pad" and every unrecognized mode: center the image on a black
        # square whose side equals the longer dimension.
        longest = max(h, w)
        top = (longest - h) // 2
        bottom = longest - h - top
        left = (longest - w) // 2
        right = longest - w - left
        img = cv2.copyMakeBorder(
            img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=[0, 0, 0]
        )

    # Shared tail: after squaring (or not, for "stretch"), scale to target.
    if img.shape[:2] != target_size:
        return cv2.resize(img, target_size, interpolation=cv2.INTER_AREA)
    return img
69+
1670
def handle_request(session, request, raw_image=None):
1771
"""Universal request handler for ZeroMQ+Pickle and TCP+JSON+RawBytes protocols."""
1872
with session_lock:
@@ -29,7 +83,7 @@ def handle_request(session, request, raw_image=None):
2983
return {"status": "error", "message": "Unknown type"}
3084

3185

32-
def read_image_from_conn(conn, expected_size=None):
86+
def read_image_from_conn(conn, expected_size=None, resize_mode='pad'):
3387
"""
3488
Reads an image from the connection.
3589
If expected_size is provided, reads exactly that many bytes.
@@ -53,8 +107,7 @@ def read_image_from_conn(conn, expected_size=None):
53107
# OpenCV loads as BGR.
54108
# We need RGB.
55109
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
56-
if img.shape[0] != 256 or img.shape[1] != 256:
57-
img = cv2.resize(img, (256, 256), interpolation=cv2.INTER_AREA)
110+
img = preprocess_image(img, resize_mode)
58111
return img
59112
except Exception:
60113
pass
@@ -143,7 +196,7 @@ def read_image_from_conn(conn, expected_size=None):
143196
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
144197

145198
if actual_width != 256 or actual_height != 256:
146-
img = cv2.resize(img, (256, 256), interpolation=cv2.INTER_AREA)
199+
img = preprocess_image(img, resize_mode)
147200

148201
return img
149202
except struct.error:
@@ -233,7 +286,8 @@ def run_tcp_server(session, port):
233286
img = None
234287
if req.get("type") == "predict":
235288
expected_len = req.get("len")
236-
img = read_image_from_conn(conn, expected_size=expected_len)
289+
resize_mode = req.get("resize_mode", "pad")
290+
img = read_image_from_conn(conn, expected_size=expected_len, resize_mode=resize_mode)
237291
if img is None:
238292
print("Incomplete or invalid image data received")
239293
break

tests/test_preprocessing.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
import sys
2+
import os
3+
import pytest
4+
from unittest.mock import MagicMock, call
5+
6+
# Add scripts to path
7+
sys.path.append(os.path.join(os.path.dirname(__file__), "../scripts"))
8+
9+
# Import serve (mocks are already applied by conftest.py)
10+
try:
11+
import serve
12+
except ImportError:
13+
serve = None
14+
15+
import cv2
16+
import numpy as np
17+
18+
@pytest.mark.skipif(serve is None, reason="Dependencies missing")
class TestPreprocessing:
    """Unit tests for ``serve.preprocess_image``.

    ``cv2`` is expected to be the mock installed by conftest.py, so these
    tests check which cv2 calls are made (and with which arguments) rather
    than inspecting pixel data. Images are MagicMocks with a real ``shape``
    tuple.
    """

    def setup_method(self):
        # Drop call history recorded by earlier tests.
        cv2.reset_mock()
        # Both cv2 operations yield a fresh 256x256 mock unless a test
        # overrides the return value.
        cv2.resize.return_value = MagicMock(shape=(256, 256, 3))
        cv2.copyMakeBorder.return_value = MagicMock(shape=(256, 256, 3))

    def test_preprocess_none(self):
        # A missing image is passed through as None without touching cv2.
        assert serve.preprocess_image(None, "pad") is None
        cv2.resize.assert_not_called()
        cv2.copyMakeBorder.assert_not_called()

    def test_preprocess_stretch(self):
        img = MagicMock()
        img.shape = (200, 100, 3)  # Height 200, Width 100

        res = serve.preprocess_image(img, "stretch")

        # Stretch resizes the original image directly, with no padding.
        cv2.copyMakeBorder.assert_not_called()
        cv2.resize.assert_called_once()
        args, kwargs = cv2.resize.call_args
        assert args[0] is img
        assert args[1] == (256, 256)
        assert kwargs.get('interpolation') == cv2.INTER_AREA
        assert res is cv2.resize.return_value

    def test_preprocess_stretch_no_op(self):
        # Already 256x256: the input must be returned untouched.
        img = MagicMock()
        img.shape = (256, 256, 3)

        res = serve.preprocess_image(img, "stretch")

        cv2.resize.assert_not_called()
        assert res is img

    def test_preprocess_crop(self):
        # Height 100, Width 200
        img = MagicMock()
        img.shape = (100, 200, 3)
        # Slicing returns a new mock that is square after the crop.
        sliced_img = MagicMock()
        sliced_img.shape = (100, 100, 3)
        img.__getitem__.return_value = sliced_img

        res = serve.preprocess_image(img, "crop")

        # min_dim = 100, so the centered window is rows 0..100 and
        # columns 50..150. MagicMock records the index tuple it received.
        img.__getitem__.assert_called_once_with((slice(0, 100), slice(50, 150)))

        # The cropped square (not the original) is what gets resized.
        cv2.resize.assert_called_once()
        args, kwargs = cv2.resize.call_args
        assert args[0] is sliced_img
        assert args[1] == (256, 256)
        assert kwargs.get('interpolation') == cv2.INTER_AREA
        assert res is cv2.resize.return_value

    def test_preprocess_pad(self):
        # Height 200, Width 100
        img = MagicMock()
        img.shape = (200, 100, 3)

        padded_img = MagicMock()
        padded_img.shape = (200, 200, 3)  # Square after padding
        cv2.copyMakeBorder.return_value = padded_img

        res = serve.preprocess_image(img, "pad")

        # max_dim = 200: nothing added top/bottom, (200 - 100) // 2 = 50
        # added on each of left/right.
        cv2.copyMakeBorder.assert_called_once_with(
            img, 0, 0, 50, 50, cv2.BORDER_CONSTANT, value=[0, 0, 0]
        )

        # The padded square is then resized.
        cv2.resize.assert_called_once()
        args, kwargs = cv2.resize.call_args
        assert args[0] is padded_img
        assert args[1] == (256, 256)
        assert kwargs.get('interpolation') == cv2.INTER_AREA
        assert res is cv2.resize.return_value

    def test_preprocess_default(self):
        # Omitting the mode must behave exactly like "pad".
        img = MagicMock()
        img.shape = (200, 100, 3)

        padded_img = MagicMock()
        padded_img.shape = (200, 200, 3)
        cv2.copyMakeBorder.return_value = padded_img

        res = serve.preprocess_image(img)  # No mode

        cv2.copyMakeBorder.assert_called_once_with(
            img, 0, 0, 50, 50, cv2.BORDER_CONSTANT, value=[0, 0, 0]
        )
        cv2.resize.assert_called_once()
        assert cv2.resize.call_args[0][0] is padded_img
        assert res is cv2.resize.return_value

    def test_preprocess_unknown_mode_falls_back_to_pad(self):
        # An unrecognized mode is handled with the pad logic.
        img = MagicMock()
        img.shape = (200, 100, 3)

        padded_img = MagicMock()
        padded_img.shape = (200, 200, 3)
        cv2.copyMakeBorder.return_value = padded_img

        res = serve.preprocess_image(img, "bogus")

        cv2.copyMakeBorder.assert_called_once_with(
            img, 0, 0, 50, 50, cv2.BORDER_CONSTANT, value=[0, 0, 0]
        )
        cv2.resize.assert_called_once()
        assert cv2.resize.call_args[0][0] is padded_img
        assert res is cv2.resize.return_value

0 commit comments

Comments
 (0)