diff --git a/demo.py b/demo.py
index 557c43d..dfcc3fd 100644
--- a/demo.py
+++ b/demo.py
@@ -237,6 +237,7 @@ def prompt_progress_callback(percent):
         model_kit,
         prompt_tokens,
         images_b64=images_base64,
+        max_image_size=(1024, 1024),
         stop_strings=args.stop_strings,
         max_tokens=1024,
         top_logprobs=args.top_logprobs,
diff --git a/mlx_engine/generate.py b/mlx_engine/generate.py
index ab80ef3..e96b62e 100644
--- a/mlx_engine/generate.py
+++ b/mlx_engine/generate.py
@@ -142,6 +142,7 @@ def create_generator(
     *,
     prompt_progress_callback: Optional[Callable[[float], bool]] = None,
     images_b64: Optional[List[str]] = None,
+    max_image_size: Optional[tuple[int, int]] = None,
     stop_strings: Optional[List[str]] = None,
     top_logprobs: Optional[int] = None,
     repetition_penalty: Optional[float] = None,
@@ -171,6 +172,9 @@ def create_generator(
             generation progress as a float between 0 and 100. Callback should return True to
             continue prompt processing, or False to stop generation
         images_b64 (Optional[List[str]]): List of base64-encoded images for vision-language models
+        max_image_size (Optional[tuple[int, int]]): Maximum dimensions (width, height) for images.
+            Images will be resized to fit within these dimensions while maintaining aspect ratio if
+            they exceed this size. If None, no resizing.
         stop_strings (Optional[List[str]]): List of strings that will trigger generation to stop
             when encountered
         top_logprobs (Optional[int]): Number of top token probabilities to return per token
@@ -245,6 +249,7 @@ def create_generator(
                 images_b64,
                 prompt_progress_callback,
                 generate_args,
+                max_image_size,
                 speculative_decoding_toggle,
             )
         except StopPromptProcessing:
diff --git a/mlx_engine/model_kit/model_kit.py b/mlx_engine/model_kit/model_kit.py
index 7c6f52c..20de93e 100644
--- a/mlx_engine/model_kit/model_kit.py
+++ b/mlx_engine/model_kit/model_kit.py
@@ -143,6 +143,7 @@ def process_prompt(
         images_b64: Optional[List[str]],
         prompt_progress_callback: Optional[Callable[[float], bool]],
         generate_args: dict,
+        max_image_size: tuple[int, int] | None,
         speculative_decoding_toggle: Optional[bool] = None,
     ) -> Tuple[mx.array, Optional[mx.array]]:
         ### TEXT-ONLY PROCESS_PROMPT ###
@@ -170,7 +171,7 @@ def process_prompt(
             )
             self._cross_prompt_cache_active = False
             input_ids, embeddings = self.vision_add_on.compute_embeddings(
-                self.model, prompt_tokens, images_b64
+                self.model, prompt_tokens, images_b64, max_size=max_image_size
             )
             return input_ids, embeddings
diff --git a/mlx_engine/model_kit/vision_add_ons/base.py b/mlx_engine/model_kit/vision_add_ons/base.py
index 669be8b..ad0149c 100644
--- a/mlx_engine/model_kit/vision_add_ons/base.py
+++ b/mlx_engine/model_kit/vision_add_ons/base.py
@@ -21,7 +21,14 @@ def compute_embeddings(
         text_model: nn.Module,
         prompt_tokens: mx.array,
         images_b64: list[str],
+        max_size: tuple[int, int] | None,
     ) -> tuple[mx.array, mx.array]:
         """
         Returns input ids and input embeddings for the language model after text/image merging of the prompt.
+
+        Args:
+            text_model: Text model for embedding tokens
+            prompt_tokens: Input prompt tokens
+            images_b64: List of base64-encoded images
+            max_size: Maximum image size as (width, height) tuple. If None, no resizing.
""" diff --git a/mlx_engine/model_kit/vision_add_ons/gemma3.py b/mlx_engine/model_kit/vision_add_ons/gemma3.py index 03cb254..eb1c4cd 100644 --- a/mlx_engine/model_kit/vision_add_ons/gemma3.py +++ b/mlx_engine/model_kit/vision_add_ons/gemma3.py @@ -49,6 +49,7 @@ def compute_embeddings( text_model: nn.Module, prompt_tokens: mx.array, images_b64: list[str], + max_size: tuple[int, int] | None, ) -> tuple[mx.array, mx.array]: """Compute input_ids and embeddings for text with images.""" input_ids, pixel_values, attention_mask, other_model_inputs = ( @@ -57,6 +58,7 @@ def compute_embeddings( images_b64=images_b64, processor=self.processor, config=self.config, + max_size=max_size, ) ) input_embeddings = text_model.language_model.model.embed_tokens(input_ids) diff --git a/mlx_engine/model_kit/vision_add_ons/gemma3n.py b/mlx_engine/model_kit/vision_add_ons/gemma3n.py index 8e746bb..f5e5f5f 100644 --- a/mlx_engine/model_kit/vision_add_ons/gemma3n.py +++ b/mlx_engine/model_kit/vision_add_ons/gemma3n.py @@ -94,6 +94,7 @@ def compute_embeddings( text_model: nn.Module, prompt_tokens: mx.array, images_b64: list[str], + max_size: tuple[int, int] | None, ) -> tuple[mx.array, mx.array]: """Compute input_ids and embeddings for text with images.""" input_ids, pixel_values, attention_mask, other_model_inputs = ( @@ -102,6 +103,7 @@ def compute_embeddings( images_b64=images_b64, processor=self.processor, config=self.config, + max_size=max_size, ) ) assert input_ids is not None diff --git a/mlx_engine/model_kit/vision_add_ons/lfm2_vl.py b/mlx_engine/model_kit/vision_add_ons/lfm2_vl.py index 302ee3f..a0382a3 100644 --- a/mlx_engine/model_kit/vision_add_ons/lfm2_vl.py +++ b/mlx_engine/model_kit/vision_add_ons/lfm2_vl.py @@ -57,6 +57,7 @@ def compute_embeddings( text_model: nn.Module, prompt_tokens: mx.array, images_b64: list[str], + max_size: tuple[int, int] | None, ) -> tuple[mx.array, mx.array]: """ Compute embeddings for text with images. @@ -71,6 +72,7 @@ def compute_embeddings( images_b64=images_b64, processor=self.processor, config=self.config, + max_size=max_size, ) ) diff --git a/mlx_engine/model_kit/vision_add_ons/mistral3.py b/mlx_engine/model_kit/vision_add_ons/mistral3.py index 0ab4868..0205896 100644 --- a/mlx_engine/model_kit/vision_add_ons/mistral3.py +++ b/mlx_engine/model_kit/vision_add_ons/mistral3.py @@ -47,6 +47,7 @@ def compute_embeddings( text_model: nn.Module, prompt_tokens: mx.array, images_b64: list[str], + max_size: tuple[int, int] | None, ) -> tuple[mx.array, mx.array]: """ Compute embeddings for text with images. 
@@ -61,6 +62,7 @@ def compute_embeddings(
                 images_b64=images_b64,
                 processor=self.processor,
                 config=self.config,
+                max_size=max_size,
             )
         )
 
diff --git a/mlx_engine/model_kit/vision_add_ons/pixtral.py b/mlx_engine/model_kit/vision_add_ons/pixtral.py
index 4426aa3..7f921c6 100644
--- a/mlx_engine/model_kit/vision_add_ons/pixtral.py
+++ b/mlx_engine/model_kit/vision_add_ons/pixtral.py
@@ -51,6 +51,7 @@ def compute_embeddings(
         text_model: nn.Module,
         prompt_tokens: mx.array,
         images_b64: list[str],
+        max_size: tuple[int, int] | None,
     ) -> tuple[mx.array, mx.array]:
         """Compute input_ids and embeddings for text with images."""
         input_ids, pixel_values, attention_mask, other_model_inputs = (
@@ -59,6 +60,7 @@ def compute_embeddings(
                 images_b64=images_b64,
                 processor=self.processor,
                 config=self.config,
+                max_size=max_size,
             )
         )
         input_embeddings = text_model.language_model.model.embed_tokens(input_ids)
diff --git a/mlx_engine/model_kit/vision_add_ons/process_prompt_with_images.py b/mlx_engine/model_kit/vision_add_ons/process_prompt_with_images.py
index 5457a98..0dde970 100644
--- a/mlx_engine/model_kit/vision_add_ons/process_prompt_with_images.py
+++ b/mlx_engine/model_kit/vision_add_ons/process_prompt_with_images.py
@@ -21,10 +21,18 @@ def common_process_prompt_with_images(
     images_b64: List[str],
     processor: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
     config,  # expected to be a ModelConfig object as defined by mlx-vlm. Can vary by model
+    max_size: tuple[int, int] | None,
 ) -> ProcessedImagePrompt:
     """
     Common prompt processing used by mlx-vlm vision add-ons.
     Returns a named tuple with all processed inputs.
+
+    Args:
+        prompt_tokens: Input prompt tokens
+        images_b64: List of base64-encoded images
+        processor: Tokenizer/processor for the model
+        config: Model configuration object
+        max_size: Maximum image size as (width, height) tuple. If None, no resizing.
     """
     if len(images_b64) == 0:
         raise ValueError("Images must be non-empty")
@@ -37,7 +45,7 @@ def common_process_prompt_with_images(
     logger.info(f"Prompt dump: {prompt}\n")
 
     images = convert_to_pil(images_b64)
-    images = custom_resize(images)
+    images = custom_resize(images, max_size=max_size)
 
     if hasattr(config, "image_token_index"):
         image_token_index = config.image_token_index
diff --git a/mlx_engine/model_kit/vision_add_ons/qwen2_vl.py b/mlx_engine/model_kit/vision_add_ons/qwen2_vl.py
index 38b7fa3..124892b 100644
--- a/mlx_engine/model_kit/vision_add_ons/qwen2_vl.py
+++ b/mlx_engine/model_kit/vision_add_ons/qwen2_vl.py
@@ -75,6 +75,7 @@ def compute_embeddings(
         text_model: nn.Module,
         prompt_tokens: mx.array,
         images_b64: list[str],
+        max_size: tuple[int, int] | None,
     ) -> tuple[mx.array, mx.array]:
         """
         Compute input_ids and embeddings for text with images.
@@ -86,4 +87,5 @@ def compute_embeddings(
             prompt_tokens=prompt_tokens,
             images_b64=images_b64,
             qwen_vl_version=2,
+            max_size=max_size,
         )
diff --git a/mlx_engine/model_kit/vision_add_ons/qwen3_vl.py b/mlx_engine/model_kit/vision_add_ons/qwen3_vl.py
index f7b685e..33bfa36 100644
--- a/mlx_engine/model_kit/vision_add_ons/qwen3_vl.py
+++ b/mlx_engine/model_kit/vision_add_ons/qwen3_vl.py
@@ -47,6 +47,7 @@ def compute_embeddings(
         text_model: nn.Module,
         prompt_tokens: mx.array,
         images_b64: list[str],
+        max_size: tuple[int, int] | None,
     ) -> tuple[mx.array, mx.array]:
         """
         Compute input_ids and embeddings for text with images.
@@ -58,4 +59,5 @@ def compute_embeddings(
             prompt_tokens=prompt_tokens,
             images_b64=images_b64,
             qwen_vl_version=3,
+            max_size=max_size,
         )
diff --git a/mlx_engine/model_kit/vision_add_ons/qwen3_vl_moe.py b/mlx_engine/model_kit/vision_add_ons/qwen3_vl_moe.py
index 2122bba..0b5a4ff 100644
--- a/mlx_engine/model_kit/vision_add_ons/qwen3_vl_moe.py
+++ b/mlx_engine/model_kit/vision_add_ons/qwen3_vl_moe.py
@@ -47,6 +47,7 @@ def compute_embeddings(
         text_model: nn.Module,
         prompt_tokens: mx.array,
         images_b64: list[str],
+        max_size: tuple[int, int] | None,
     ) -> tuple[mx.array, mx.array]:
         """
         Compute input_ids and embeddings for text with images.
@@ -58,4 +59,5 @@ def compute_embeddings(
             prompt_tokens=prompt_tokens,
             images_b64=images_b64,
             qwen_vl_version=3,
+            max_size=max_size,
         )
diff --git a/mlx_engine/model_kit/vision_add_ons/qwen_vl_utils.py b/mlx_engine/model_kit/vision_add_ons/qwen_vl_utils.py
index e79af02..a1845b0 100644
--- a/mlx_engine/model_kit/vision_add_ons/qwen_vl_utils.py
+++ b/mlx_engine/model_kit/vision_add_ons/qwen_vl_utils.py
@@ -13,6 +13,7 @@ def compute_qwen_vl_embeddings(
     prompt_tokens: mx.array,
     images_b64: list[str],
     qwen_vl_version: int,
+    max_size: tuple[int, int] | None,
 ) -> tuple[mx.array, mx.array]:
     """
     Compute input_ids and embeddings for Qwen2-VL, Qwen2.5-VL, and Qwen3-VL models.
@@ -23,6 +24,7 @@ def compute_qwen_vl_embeddings(
         prompt_tokens: Input prompt tokens
         images_b64: List of base64-encoded images
         qwen_vl_version: Version number (2 for Qwen2/2.5-VL, 3 for Qwen3-VL)
+        max_size: Maximum image size as (width, height) tuple. If None, no resizing.
 
     Returns:
         Tuple of (input_ids, final_embeddings) with batch dimension removed
@@ -30,7 +32,7 @@ def compute_qwen_vl_embeddings(
 
     # Convert and resize images
     images = convert_to_pil(images_b64)
-    images = custom_resize(images, should_pad=False)
+    images = custom_resize(images, max_size=max_size, should_pad=False)
 
     # Build prompt text
     tokens = (
diff --git a/mlx_engine/utils/image_utils.py b/mlx_engine/utils/image_utils.py
index f49fb41..481f843 100644
--- a/mlx_engine/utils/image_utils.py
+++ b/mlx_engine/utils/image_utils.py
@@ -7,7 +7,7 @@
 logger = logging.getLogger(__name__)
 
 
-def convert_to_pil(images_b64: List[str]) -> List[PIL.Image.Image]:
+def convert_to_pil(images_b64: List[str]) -> list[PIL.Image.Image]:
     """Convert a list of base64 strings to PIL Images"""
     return [
         PIL.Image.open(BytesIO(base64.b64decode(img))).convert("RGB")
@@ -15,7 +15,12 @@ def convert_to_pil(images_b64: List[str]) -> list[PIL.Image.Image]:
     ]
 
 
-def custom_resize(pil_images, max_size=(1000, 1000), should_pad=True):
+def custom_resize(
+    pil_images: list[PIL.Image.Image],
+    *,
+    max_size: tuple[int, int] | None,
+    should_pad: bool = True,
+):
     """
     Resize and optionally pad a list of PIL images.
 
@@ -26,7 +31,7 @@ def custom_resize(pil_images, max_size=(1000, 1000), should_pad=True):
     Args:
         pil_images (list): A list of PIL Image objects to be processed.
         max_size (tuple): Maximum allowed dimensions (width, height) for the images.
-            Defaults to (1000, 1000).
+            If None, no resizing is performed.
         should_pad (bool): Whether to pad the images to the same size.
             Defaults to True.
 
@@ -38,6 +43,15 @@ def custom_resize(pil_images, max_size=(1000, 1000), should_pad=True):
     Side effects:
         Writes progress and status messages to sys.stderr.
""" + # Validate max_size parameter + if max_size is not None: + if not isinstance(max_size, tuple) or len(max_size) != 2: + raise ValueError( + "max_size must be a tuple of (width, height), e.g., (1024, 1024)" + ) + if not all(isinstance(dim, int) and dim > 0 for dim in max_size): + raise ValueError("max_size dimensions must be positive integers") + resized_images = [] max_width, max_height = 0, 0 @@ -49,7 +63,9 @@ def custom_resize(pil_images, max_size=(1000, 1000), should_pad=True): f"Image {i + 1}: Original size {original_size}", ) - if img.width > max_size[0] or img.height > max_size[1]: + if max_size is not None and ( + img.width > max_size[0] or img.height > max_size[1] + ): if img.width > img.height: new_width = max_size[0] new_height = int(new_width / aspect_ratio) diff --git a/mlx_engine/vision_model_kit/vision_model_kit.py b/mlx_engine/vision_model_kit/vision_model_kit.py index dec10ce..ef8d4e3 100644 --- a/mlx_engine/vision_model_kit/vision_model_kit.py +++ b/mlx_engine/vision_model_kit/vision_model_kit.py @@ -101,6 +101,7 @@ def process_prompt( images_b64: Optional[List[str]], prompt_progress_callback, generate_args, + max_image_size: tuple[int, int] | None, speculative_decoding_toggle: Optional[bool] = None, ) -> Tuple[mx.array, Optional[mx.array]]: """ @@ -115,7 +116,7 @@ def process_prompt( self._reset_for_prediction() self.model.process_prompt_with_images( - images_b64, prompt_tokens, self.processor, self.detokenizer + images_b64, prompt_tokens, self.processor, self.detokenizer, max_image_size ) self.has_processed_prompt = True diff --git a/mlx_engine/vision_model_kit/vision_model_wrapper.py b/mlx_engine/vision_model_kit/vision_model_wrapper.py index b750f48..e71b41c 100644 --- a/mlx_engine/vision_model_kit/vision_model_wrapper.py +++ b/mlx_engine/vision_model_kit/vision_model_wrapper.py @@ -157,6 +157,7 @@ def process_prompt_with_images( prompt_tokens: mx.array, processor, detokenizer, + max_image_size: tuple[int, int] | None, ): """ This method generates the input_ids, pixel_values, and mask for the vision model @@ -195,6 +196,7 @@ def process_prompt_with_images( images_b64=images_b64, processor=processor, config=self.vision_model.config, + max_size=max_image_size, ) # Set class attributes from the processed result diff --git a/tests/test_vision_models.py b/tests/test_vision_models.py index d711d25..ecae5c2 100644 --- a/tests/test_vision_models.py +++ b/tests/test_vision_models.py @@ -12,6 +12,9 @@ from textwrap import dedent +MAX_IMAGE_SIZE = (1024, 1024) + + class TestVisionModels(unittest.TestCase): @classmethod def setUpClass(cls): @@ -58,6 +61,7 @@ def toucan_test_runner( model_kit=model_kit, prompt_tokens=prompt_tokens, images_b64=([self.toucan_image_b64] if not text_only else None), + max_image_size=MAX_IMAGE_SIZE, seed=0, max_tokens=30, temp=0.0, @@ -276,6 +280,7 @@ def generate_text(prompt, images_b64): model_kit=model_kit, prompt_tokens=prompt_tokens, images_b64=images_b64, + max_image_size=MAX_IMAGE_SIZE, seed=0, temp=0.0, max_tokens=50, @@ -708,6 +713,7 @@ def progress_callback(progress: float) -> bool: model_kit=model_kit, prompt_tokens=prompt_tokens, images_b64=[self.toucan_image_b64], + max_image_size=MAX_IMAGE_SIZE, seed=0, temp=0.0, max_tokens=1, # We only care about pre-fill in this test