Skip to content

Commit a612d30

Browse files
[Fix] Adjust Resize logic - closer to PyTorch (#76)
1 parent af8ea68 commit a612d30

File tree

1 file changed

+34
-18
lines changed

1 file changed

+34
-18
lines changed

onnxtr/transforms/base.py

Lines changed: 34 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,8 @@
44
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
55

66

7+
import math
8+
79
import numpy as np
810
from PIL import Image, ImageOps
911

@@ -37,37 +39,51 @@ def __init__(
3739
raise AssertionError("size should be either a tuple or an int")
3840

3941
def __call__(self, img: np.ndarray) -> np.ndarray:
40-
img = (img * 255).astype(np.uint8) if img.dtype != np.uint8 else img
41-
h, w = img.shape[:2] if img.ndim == 3 else img.shape[1:3]
42+
if img.dtype != np.uint8:
43+
img_pil = Image.fromarray((img * 255).clip(0, 255).astype(np.uint8))
44+
else:
45+
img_pil = Image.fromarray(img)
46+
4247
sh, sw = self.size
48+
w, h = img_pil.size
4349

4450
if not self.preserve_aspect_ratio:
45-
return np.array(Image.fromarray(img).resize((sw, sh), resample=self.interpolation))
51+
img_resized_pil = img_pil.resize((sw, sh), resample=self.interpolation)
52+
return np.array(img_resized_pil)
4653

4754
actual_ratio = h / w
4855
target_ratio = sh / sw
4956

50-
if target_ratio == actual_ratio:
51-
return np.array(Image.fromarray(img).resize((sw, sh), resample=self.interpolation))
52-
5357
if actual_ratio > target_ratio:
54-
tmp_size = (int(sh / actual_ratio), sh)
58+
new_h = sh
59+
new_w = max(int(sh / actual_ratio), 1)
5560
else:
56-
tmp_size = (sw, int(sw * actual_ratio))
61+
new_w = sw
62+
new_h = max(int(sw * actual_ratio), 1)
5763

58-
img_resized = Image.fromarray(img).resize(tmp_size, resample=self.interpolation)
59-
pad_left = pad_top = 0
60-
pad_right = sw - img_resized.width
61-
pad_bottom = sh - img_resized.height
64+
img_resized_pil = img_pil.resize((new_w, new_h), resample=self.interpolation)
65+
66+
delta_w = sw - new_w
67+
delta_h = sh - new_h
6268

6369
if self.symmetric_pad:
64-
pad_left = pad_right // 2
65-
pad_right -= pad_left
66-
pad_top = pad_bottom // 2
67-
pad_bottom -= pad_top
70+
# Symmetric padding
71+
pad_left = math.ceil(delta_w / 2)
72+
pad_right = math.floor(delta_w / 2)
73+
pad_top = math.ceil(delta_h / 2)
74+
pad_bottom = math.floor(delta_h / 2)
75+
else:
76+
# Asymmetric padding
77+
pad_left, pad_top = 0, 0
78+
pad_right, pad_bottom = delta_w, delta_h
79+
80+
img_padded_pil = ImageOps.expand(
81+
img_resized_pil,
82+
border=(pad_left, pad_top, pad_right, pad_bottom),
83+
fill=0,
84+
)
6885

69-
img_resized = ImageOps.expand(img_resized, (pad_left, pad_top, pad_right, pad_bottom))
70-
return np.array(img_resized)
86+
return np.array(img_padded_pil)
7187

7288
def __repr__(self) -> str:
7389
interpolate_str = self.interpolation

0 commit comments

Comments
 (0)