Skip to content

Commit 143a7fe

Browse files
artemorloffZenMan123baberabb
authored
Adding resize images support (#2958)
* first version of image resizing * fixed bug * clean up `resize_image` --------- Co-authored-by: Artem Safin <[email protected]> Co-authored-by: Baber <[email protected]>
1 parent 2cfdd0a commit 143a7fe

File tree

3 files changed

+146
-2
lines changed

3 files changed

+146
-2
lines changed

lm_eval/models/hf_vlms.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
handle_stop_sequences,
1818
pad_and_concat,
1919
replace_placeholders,
20+
resize_image,
2021
stop_sequences_criteria,
2122
)
2223

@@ -45,10 +46,23 @@ def __init__(
4546
# TODO: handle whitespace in image placeholder (replacement)
4647
max_images: Optional[int] = 999,
4748
convert_img_format=False,
49+
# For image resizing
4850
min_pixels: Optional[int] = None,
4951
max_pixels: Optional[int] = None,
52+
image_width: Optional[int] = None,
53+
image_height: Optional[int] = None,
54+
image_max_side: Optional[int] = None,
5055
**kwargs,
5156
):
57+
self.image_width = image_width
58+
self.image_height = image_height
59+
self.image_max_side = image_max_side
60+
if self.image_max_side and (self.image_width or self.image_height):
61+
raise ValueError(
62+
"Ambiguous config for image resize: you can not specify both "
63+
"image_max_side and (image_width or image_height)"
64+
)
65+
5266
# init pixels before calling tokenizer creation to avoid errors
5367
self.pixels = ({"min_pixels": min_pixels} if min_pixels else {}) | (
5468
{"max_pixels": max_pixels} if max_pixels else {}
@@ -646,7 +660,15 @@ def _collate(x):
646660
for chunk in chunks:
647661
contexts, all_gen_kwargs, aux_arguments = zip(*chunk)
648662

649-
visuals = [arg["visual"] for arg in aux_arguments]
663+
visuals = [
664+
[
665+
resize_image(
666+
img, self.image_width, self.image_height, self.image_max_side
667+
)
668+
for img in arg["visual"]
669+
]
670+
for arg in aux_arguments
671+
]
650672

651673
if not isinstance(contexts, list):
652674
contexts = list(

lm_eval/models/utils.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929

3030
if TYPE_CHECKING:
31+
from PIL import Image
3132
from transformers import PreTrainedTokenizerBase
3233
from transformers.configuration_utils import PretrainedConfig
3334

@@ -729,3 +730,103 @@ def handle_stop_sequences(
729730
if eos is not None and eos not in until:
730731
until.append(eos)
731732
return until
733+
734+
735+
def resize_image(
736+
image: "Image.Image",
737+
width: Optional[int] = None,
738+
height: Optional[int] = None,
739+
max_dimension: Optional[int] = None,
740+
keep_aspect_ratio: bool = True,
741+
resample_filter: Union[int, str] = "Image.BICUBIC",
742+
min_width: int = 1,
743+
min_height: int = 1,
744+
) -> "Image.Image":
745+
"""
746+
Resizes a PIL Image object with flexible options.
747+
748+
Args:
749+
image: The PIL Image object to resize.
750+
width: Target width in pixels.
751+
height: Target height in pixels.
752+
max_dimension: Maximum size for the longer dimension of the image.
753+
keep_aspect_ratio: If True (default) and both width and height are provided,
754+
the image is resized to fit within these dimensions while
755+
maintaining its aspect ratio. If False, the image is stretched
756+
to the exact width and height.
757+
resample_filter: The resampling filter to use for resizing.
758+
Defaults to Image.BICUBIC.
759+
min_width: Minimum width for the resized image. Defaults to 1.
760+
min_height: Minimum height for the resized image. Defaults to 1.
761+
762+
Returns:
763+
The resized PIL Image object. If no resize parameters are provided
764+
or if the image already meets the criteria, the original image is returned.
765+
766+
Order of precedence for resizing:
767+
1. If width AND height are provided:
768+
- If keep_aspect_ratio is True: Fits image within bounds, preserving aspect ratio.
769+
- If keep_aspect_ratio is False: Resizes to exact dimensions (may distort).
770+
2. Else if only width is provided: Calculates height proportionally.
771+
3. Else if only height is provided: Calculates width proportionally.
772+
4. Else if max_dimension is provided: Resizes the longest side to max_dimension
773+
and scales the other side proportionally.
774+
5. If none of the above are provided, returns the original image.
775+
"""
776+
original_width, original_height = image.size
777+
778+
# If no arguments are provided, return the original image
779+
if width is None and height is None and max_dimension is None:
780+
return image
781+
782+
new_width = original_width
783+
new_height = original_height
784+
785+
if width is not None and height is not None:
786+
# No resize needed if image is already smaller than target dimensions
787+
if original_width <= width and original_height <= height:
788+
return image
789+
790+
if keep_aspect_ratio:
791+
# Calculate the ratio to fit within the target dimensions
792+
ratio = min(width / original_width, height / original_height)
793+
new_width = int(original_width * ratio)
794+
new_height = int(original_height * ratio)
795+
else:
796+
# Stretch to exact dimensions
797+
new_width = width
798+
new_height = height
799+
elif width is not None:
800+
# No resize needed if width is already smaller
801+
if original_width <= width:
802+
return image
803+
# Calculate height proportionally
804+
new_width = width
805+
new_height = int((original_height / original_width) * new_width)
806+
elif height is not None:
807+
# No resize needed if height is already smaller
808+
if original_height <= height:
809+
return image
810+
# Calculate width proportionally
811+
new_height = height
812+
new_width = int((original_width / original_height) * new_height)
813+
elif max_dimension is not None:
814+
# No resize needed if both dimensions are smaller than max_dimension
815+
if max(original_height, original_width) <= max_dimension:
816+
return image
817+
818+
if original_width > original_height:
819+
# Width is the longer side
820+
new_width = max_dimension
821+
new_height = int((original_height / original_width) * new_width)
822+
else:
823+
# Height is the longer side or sides are equal
824+
new_height = max_dimension
825+
new_width = int((original_width / original_height) * new_height)
826+
827+
# Ensure dimensions are at least minimum values
828+
new_width = max(min_width, new_width)
829+
new_height = max(min_height, new_height)
830+
831+
# Perform the resize operation with the calculated dimensions
832+
return image.resize((new_width, new_height), resample_filter)

lm_eval/models/vllm_vlms.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
Collator,
1313
handle_stop_sequences,
1414
replace_placeholders,
15+
resize_image,
1516
undistribute,
1617
)
1718
from lm_eval.models.vllm_causallms import VLLM
@@ -44,8 +45,20 @@ def __init__(
4445
interleave: bool = True,
4546
# TODO<baber>: handle max_images and limit_mm_per_prompt better
4647
max_images: int = 999,
48+
image_width: Optional[int] = None,
49+
image_height: Optional[int] = None,
50+
image_max_side: Optional[int] = None,
4751
**kwargs,
4852
):
53+
self.image_width = image_width
54+
self.image_height = image_height
55+
self.image_max_side = image_max_side
56+
if self.image_max_side and (self.image_width or self.image_height):
57+
raise ValueError(
58+
"Ambiguous config for image resize: you can not specify both "
59+
"image_max_side and (image_width or image_height)"
60+
)
61+
4962
if max_images != 999:
5063
kwargs["limit_mm_per_prompt"] = {"image": max_images}
5164
eval_logger.info(f"Setting limit_mm_per_prompt[image] to {max_images}")
@@ -239,7 +252,15 @@ def _collate(x):
239252
for chunk in chunks:
240253
contexts, all_gen_kwargs, aux_arguments = zip(*chunk)
241254

242-
visuals = [arg["visual"] for arg in aux_arguments]
255+
visuals = [
256+
[
257+
resize_image(
258+
img, self.image_width, self.image_height, self.image_max_side
259+
)
260+
for img in arg["visual"]
261+
]
262+
for arg in aux_arguments
263+
]
243264

244265
if not isinstance(contexts, list):
245266
contexts = list(

0 commit comments

Comments
 (0)