1 change: 1 addition & 0 deletions demo.py
@@ -237,6 +237,7 @@ def prompt_progress_callback(percent):
model_kit,
prompt_tokens,
images_b64=images_base64,
max_image_size=(1024, 1024),
stop_strings=args.stop_strings,
max_tokens=1024,
top_logprobs=args.top_logprobs,
5 changes: 5 additions & 0 deletions mlx_engine/generate.py
@@ -142,6 +142,7 @@ def create_generator(
*,
prompt_progress_callback: Optional[Callable[[float], bool]] = None,
images_b64: Optional[List[str]] = None,
max_image_size: Optional[tuple[int, int]] = None,
stop_strings: Optional[List[str]] = None,
top_logprobs: Optional[int] = None,
repetition_penalty: Optional[float] = None,
@@ -171,6 +172,9 @@ def create_generator(
generation progress as a float between 0 and 100. Callback should return True to continue
prompt processing, or False to stop generation
images_b64 (Optional[List[str]]): List of base64-encoded images for vision-language models
max_image_size (Optional[tuple[int, int]]): Maximum dimensions (width, height) for images.
Images will be resized to fit within these dimensions while maintaining aspect ratio if
they exceed this size. If None, no resizing.
stop_strings (Optional[List[str]]): List of strings that will trigger generation to stop
when encountered
top_logprobs (Optional[int]): Number of top token probabilities to return per token
@@ -245,6 +249,7 @@ def create_generator(
images_b64,
prompt_progress_callback,
generate_args,
max_image_size,
speculative_decoding_toggle,
)
except StopPromptProcessing:
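For orientation, a minimal call-site sketch for the new keyword, mirroring the demo.py change above. Here model_kit, prompt_tokens, and image_b64 are placeholders assumed to come from the caller's existing setup; only the max_image_size keyword is new in this PR.

# Hypothetical call site; model_kit, prompt_tokens, and image_b64 are placeholders.
generator = create_generator(
    model_kit,
    prompt_tokens,
    images_b64=[image_b64],
    max_image_size=(1024, 1024),  # images larger than this are shrunk, aspect ratio kept
    max_tokens=128,
)
for result in generator:
    ...  # consume generation output as before
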
3 changes: 2 additions & 1 deletion mlx_engine/model_kit/model_kit.py
@@ -143,6 +143,7 @@ def process_prompt(
images_b64: Optional[List[str]],
prompt_progress_callback: Optional[Callable[[float], bool]],
generate_args: dict,
max_image_size: tuple[int, int] | None,
speculative_decoding_toggle: Optional[bool] = None,
) -> Tuple[mx.array, Optional[mx.array]]:
### TEXT-ONLY PROCESS_PROMPT ###
@@ -170,7 +171,7 @@ def process_prompt(
)
self._cross_prompt_cache_active = False
input_ids, embeddings = self.vision_add_on.compute_embeddings(
self.model, prompt_tokens, images_b64
self.model, prompt_tokens, images_b64, max_size=max_image_size
)
return input_ids, embeddings

7 changes: 7 additions & 0 deletions mlx_engine/model_kit/vision_add_ons/base.py
@@ -21,7 +21,14 @@ def compute_embeddings(
text_model: nn.Module,
prompt_tokens: mx.array,
images_b64: list[str],
max_size: tuple[int, int] | None,
) -> tuple[mx.array, mx.array]:
"""
Returns input ids and input embeddings for the language model after text/image merging of the prompt.

Args:
text_model: Text model for embedding tokens
prompt_tokens: Input prompt tokens
images_b64: List of base64-encoded images
max_size: Maximum image size as (width, height) tuple. If None, no resizing.
"""
2 changes: 2 additions & 0 deletions mlx_engine/model_kit/vision_add_ons/gemma3.py
@@ -49,6 +49,7 @@ def compute_embeddings(
text_model: nn.Module,
prompt_tokens: mx.array,
images_b64: list[str],
max_size: tuple[int, int] | None,
) -> tuple[mx.array, mx.array]:
"""Compute input_ids and embeddings for text with images."""
input_ids, pixel_values, attention_mask, other_model_inputs = (
@@ -57,6 +58,7 @@
images_b64=images_b64,
processor=self.processor,
config=self.config,
max_size=max_size,
)
)
input_embeddings = text_model.language_model.model.embed_tokens(input_ids)
2 changes: 2 additions & 0 deletions mlx_engine/model_kit/vision_add_ons/gemma3n.py
@@ -94,6 +94,7 @@ def compute_embeddings(
text_model: nn.Module,
prompt_tokens: mx.array,
images_b64: list[str],
max_size: tuple[int, int] | None,
) -> tuple[mx.array, mx.array]:
"""Compute input_ids and embeddings for text with images."""
input_ids, pixel_values, attention_mask, other_model_inputs = (
@@ -102,6 +103,7 @@
images_b64=images_b64,
processor=self.processor,
config=self.config,
max_size=max_size,
)
)
assert input_ids is not None
2 changes: 2 additions & 0 deletions mlx_engine/model_kit/vision_add_ons/lfm2_vl.py
@@ -57,6 +57,7 @@ def compute_embeddings(
text_model: nn.Module,
prompt_tokens: mx.array,
images_b64: list[str],
max_size: tuple[int, int] | None,
) -> tuple[mx.array, mx.array]:
"""
Compute embeddings for text with images.
@@ -71,6 +72,7 @@
images_b64=images_b64,
processor=self.processor,
config=self.config,
max_size=max_size,
)
)

2 changes: 2 additions & 0 deletions mlx_engine/model_kit/vision_add_ons/mistral3.py
@@ -47,6 +47,7 @@ def compute_embeddings(
text_model: nn.Module,
prompt_tokens: mx.array,
images_b64: list[str],
max_size: tuple[int, int] | None,
) -> tuple[mx.array, mx.array]:
"""
Compute embeddings for text with images.
@@ -61,6 +62,7 @@
images_b64=images_b64,
processor=self.processor,
config=self.config,
max_size=max_size,
)
)

2 changes: 2 additions & 0 deletions mlx_engine/model_kit/vision_add_ons/pixtral.py
@@ -51,6 +51,7 @@ def compute_embeddings(
text_model: nn.Module,
prompt_tokens: mx.array,
images_b64: list[str],
max_size: tuple[int, int] | None,
) -> tuple[mx.array, mx.array]:
"""Compute input_ids and embeddings for text with images."""
input_ids, pixel_values, attention_mask, other_model_inputs = (
@@ -59,6 +60,7 @@
images_b64=images_b64,
processor=self.processor,
config=self.config,
max_size=max_size,
)
)
input_embeddings = text_model.language_model.model.embed_tokens(input_ids)
@@ -21,10 +21,18 @@ def common_process_prompt_with_images(
images_b64: List[str],
processor: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
config, # expected to be a ModelConfig object as defined by mlx-vlm. Can vary by model
max_size: tuple[int, int] | None,
) -> ProcessedImagePrompt:
"""
Common prompt processing used by mlx-vlm vision add-ons.
Returns a named tuple with all processed inputs.

Args:
prompt_tokens: Input prompt tokens
images_b64: List of base64-encoded images
processor: Tokenizer/processor for the model
config: Model configuration object
max_size: Maximum image size as (width, height) tuple. If None, no resizing.
"""
if len(images_b64) == 0:
raise ValueError("Images must be non-empty")
@@ -37,7 +45,7 @@ def common_process_prompt_with_images(
logger.info(f"Prompt dump: {prompt}\n")

images = convert_to_pil(images_b64)
images = custom_resize(images)
images = custom_resize(images, max_size=max_size)

if hasattr(config, "image_token_index"):
image_token_index = config.image_token_index
2 changes: 2 additions & 0 deletions mlx_engine/model_kit/vision_add_ons/qwen2_vl.py
@@ -75,6 +75,7 @@ def compute_embeddings(
text_model: nn.Module,
prompt_tokens: mx.array,
images_b64: list[str],
max_size: tuple[int, int] | None,
) -> tuple[mx.array, mx.array]:
"""
Compute input_ids and embeddings for text with images.
@@ -86,4 +87,5 @@
prompt_tokens=prompt_tokens,
images_b64=images_b64,
qwen_vl_version=2,
max_size=max_size,
)
2 changes: 2 additions & 0 deletions mlx_engine/model_kit/vision_add_ons/qwen3_vl.py
@@ -47,6 +47,7 @@ def compute_embeddings(
text_model: nn.Module,
prompt_tokens: mx.array,
images_b64: list[str],
max_size: tuple[int, int] | None,
) -> tuple[mx.array, mx.array]:
"""
Compute input_ids and embeddings for text with images.
@@ -58,4 +59,5 @@
prompt_tokens=prompt_tokens,
images_b64=images_b64,
qwen_vl_version=3,
max_size=max_size,
)
2 changes: 2 additions & 0 deletions mlx_engine/model_kit/vision_add_ons/qwen3_vl_moe.py
@@ -47,6 +47,7 @@ def compute_embeddings(
text_model: nn.Module,
prompt_tokens: mx.array,
images_b64: list[str],
max_size: tuple[int, int] | None,
) -> tuple[mx.array, mx.array]:
"""
Compute input_ids and embeddings for text with images.
@@ -58,4 +59,5 @@
prompt_tokens=prompt_tokens,
images_b64=images_b64,
qwen_vl_version=3,
max_size=max_size,
)
4 changes: 3 additions & 1 deletion mlx_engine/model_kit/vision_add_ons/qwen_vl_utils.py
@@ -13,6 +13,7 @@ def compute_qwen_vl_embeddings(
prompt_tokens: mx.array,
images_b64: list[str],
qwen_vl_version: int,
max_size: tuple[int, int] | None,
) -> tuple[mx.array, mx.array]:
"""
Compute input_ids and embeddings for Qwen2-VL, Qwen2.5-VL, and Qwen3-VL models.
@@ -23,14 +24,15 @@
prompt_tokens: Input prompt tokens
images_b64: List of base64-encoded images
qwen_vl_version: Version number (2 for Qwen2/2.5-VL, 3 for Qwen3-VL)
max_size: Maximum image size as (width, height) tuple. If None, no resizing.

Returns:
Tuple of (input_ids, final_embeddings) with batch dimension removed
"""

# Convert and resize images
images = convert_to_pil(images_b64)
images = custom_resize(images, should_pad=False)
images = custom_resize(images, max_size=max_size, should_pad=False)

# Build prompt text
tokens = (
24 changes: 20 additions & 4 deletions mlx_engine/utils/image_utils.py
@@ -7,15 +7,20 @@
logger = logging.getLogger(__name__)


def convert_to_pil(images_b64: List[str]) -> List[PIL.Image.Image]:
def convert_to_pil(images_b64: List[str]) -> list[PIL.Image.Image]:
"""Convert a list of base64 strings to PIL Images"""
return [
PIL.Image.open(BytesIO(base64.b64decode(img))).convert("RGB")
for img in images_b64 or []
]


def custom_resize(pil_images, max_size=(1000, 1000), should_pad=True):
def custom_resize(
pil_images: list[PIL.Image.Image],
*,
max_size: tuple[int, int] | None,
should_pad: bool = True,
):
"""
Resize and optionally pad a list of PIL images.

@@ -26,7 +31,7 @@ def custom_resize(pil_images, max_size=(1000, 1000), should_pad=True):
Args:
pil_images (list): A list of PIL Image objects to be processed.
max_size (tuple): Maximum allowed dimensions (width, height) for the images.
Defaults to (1000, 1000).
If None, no resizing is performed.
should_pad (bool): Whether to pad the images to the same size.
Defaults to True.

@@ -38,6 +43,15 @@
Side effects:
Writes progress and status messages to sys.stderr.
"""
# Validate max_size parameter
if max_size is not None:
if not isinstance(max_size, tuple) or len(max_size) != 2:
raise ValueError(
"max_size must be a tuple of (width, height), e.g., (1024, 1024)"
)
if not all(isinstance(dim, int) and dim > 0 for dim in max_size):
raise ValueError("max_size dimensions must be positive integers")

resized_images = []
max_width, max_height = 0, 0

@@ -49,7 +63,9 @@
f"Image {i + 1}: Original size {original_size}",
)

if img.width > max_size[0] or img.height > max_size[1]:
if max_size is not None and (
img.width > max_size[0] or img.height > max_size[1]
):
if img.width > img.height:
new_width = max_size[0]
new_height = int(new_width / aspect_ratio)
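For reference, a self-contained sketch of the shrink-to-fit behavior this hunk implements. The height-dominant else branch is elided by the diff view and assumed symmetric to the width case shown; fit_within is an illustration, not the PR's exact code.

import PIL.Image

def fit_within(
    img: PIL.Image.Image, max_size: tuple[int, int] | None
) -> PIL.Image.Image:
    """Shrink img to fit inside max_size (width, height), keeping aspect ratio."""
    if max_size is None or (img.width <= max_size[0] and img.height <= max_size[1]):
        return img  # no cap set, or the image already fits
    aspect_ratio = img.width / img.height
    if img.width > img.height:
        new_width = max_size[0]  # width is the binding dimension
        new_height = int(new_width / aspect_ratio)
    else:
        new_height = max_size[1]  # height is the binding dimension
        new_width = int(new_height * aspect_ratio)
    return img.resize((new_width, new_height))

Worked example: a 2048x1536 image with max_size=(1024, 1024) takes the width-dominant branch, giving new_width = 1024 and new_height = int(1024 / (2048 / 1536)) = 768. PIL's built-in Image.thumbnail offers similar shrink-only, aspect-preserving semantics, though it resizes in place.
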
3 changes: 2 additions & 1 deletion mlx_engine/vision_model_kit/vision_model_kit.py
@@ -101,6 +101,7 @@ def process_prompt(
images_b64: Optional[List[str]],
prompt_progress_callback,
generate_args,
max_image_size: tuple[int, int] | None,
speculative_decoding_toggle: Optional[bool] = None,
) -> Tuple[mx.array, Optional[mx.array]]:
"""
@@ -115,7 +116,7 @@
self._reset_for_prediction()

self.model.process_prompt_with_images(
images_b64, prompt_tokens, self.processor, self.detokenizer
images_b64, prompt_tokens, self.processor, self.detokenizer, max_image_size
)
self.has_processed_prompt = True

2 changes: 2 additions & 0 deletions mlx_engine/vision_model_kit/vision_model_wrapper.py
@@ -157,6 +157,7 @@ def process_prompt_with_images(
prompt_tokens: mx.array,
processor,
detokenizer,
max_image_size: tuple[int, int] | None,
):
"""
This method generates the input_ids, pixel_values, and mask for the vision model
@@ -195,6 +196,7 @@
images_b64=images_b64,
processor=processor,
config=self.vision_model.config,
max_size=max_image_size,
)

# Set class attributes from the processed result
6 changes: 6 additions & 0 deletions tests/test_vision_models.py
@@ -12,6 +12,9 @@
from textwrap import dedent


MAX_IMAGE_SIZE = (1024, 1024)


class TestVisionModels(unittest.TestCase):
@classmethod
def setUpClass(cls):
@@ -58,6 +61,7 @@ def toucan_test_runner(
model_kit=model_kit,
prompt_tokens=prompt_tokens,
images_b64=([self.toucan_image_b64] if not text_only else None),
max_image_size=MAX_IMAGE_SIZE,
seed=0,
max_tokens=30,
temp=0.0,
@@ -276,6 +280,7 @@ def generate_text(prompt, images_b64):
model_kit=model_kit,
prompt_tokens=prompt_tokens,
images_b64=images_b64,
max_image_size=MAX_IMAGE_SIZE,
seed=0,
temp=0.0,
max_tokens=50,
@@ -708,6 +713,7 @@ def progress_callback(progress: float) -> bool:
model_kit=model_kit,
prompt_tokens=prompt_tokens,
images_b64=[self.toucan_image_b64],
max_image_size=MAX_IMAGE_SIZE,
seed=0,
temp=0.0,
max_tokens=1, # We only care about pre-fill in this test
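To round out the test changes above, a hedged sketch of focused unit tests for the new validation and resize paths. The test names are illustrative, and the sketch assumes custom_resize returns the processed image list, as its use in the vision add-ons suggests.

import unittest

import PIL.Image

from mlx_engine.utils.image_utils import custom_resize


class TestCustomResizeMaxSize(unittest.TestCase):
    def test_none_disables_resizing(self):
        img = PIL.Image.new("RGB", (2048, 1536))
        (out,) = custom_resize([img], max_size=None, should_pad=False)
        self.assertEqual(out.size, (2048, 1536))  # untouched when max_size is None

    def test_oversized_image_shrinks_with_aspect_ratio(self):
        img = PIL.Image.new("RGB", (2048, 1536))
        (out,) = custom_resize([img], max_size=(1024, 1024), should_pad=False)
        self.assertEqual(out.size, (1024, 768))  # 4:3 aspect ratio preserved

    def test_invalid_max_size_raises(self):
        img = PIL.Image.new("RGB", (64, 64))
        with self.assertRaises(ValueError):
            custom_resize([img], max_size=[1024, 1024], should_pad=False)  # not a tuple


if __name__ == "__main__":
    unittest.main()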