1 change: 1 addition & 0 deletions demo.py
@@ -237,6 +237,7 @@ def prompt_progress_callback(percent):
     model_kit,
     prompt_tokens,
     images_b64=images_base64,
+    max_image_size=1024,
     stop_strings=args.stop_strings,
     max_tokens=1024,
     top_logprobs=args.top_logprobs,
9 changes: 9 additions & 0 deletions mlx_engine/generate.py
@@ -142,6 +142,7 @@ def create_generator(
     *,
     prompt_progress_callback: Optional[Callable[[float], bool]] = None,
     images_b64: Optional[List[str]] = None,
+    max_image_size: Optional[int] = None,
     stop_strings: Optional[List[str]] = None,
     top_logprobs: Optional[int] = None,
     repetition_penalty: Optional[float] = None,
@@ -171,6 +172,8 @@
             generation progress as a float between 0 and 100. Callback should return True to continue
             prompt processing, or False to stop generation
         images_b64 (Optional[List[str]]): List of base64-encoded images for vision-language models
+        max_image_size (Optional[int]): Maximum dimension for images (assumes square). Images will be
+            resized to (max_image_size, max_image_size) if they exceed this size. If None, no resizing.
         stop_strings (Optional[List[str]]): List of strings that will trigger generation to stop
             when encountered
         top_logprobs (Optional[int]): Number of top token probabilities to return per token
@@ -238,13 +241,19 @@ def create_generator(
         model_kit, draft_model, num_draft_tokens, generate_args
     )
 
+    # Convert max_image_size to tuple format (assumes square images)
+    max_image_size_tuple = (
+        (max_image_size, max_image_size) if max_image_size is not None else None
+    )
+
     # Process prompt
     try:
         input_tokens, input_embeddings = model_kit.process_prompt(
             prompt_tokens,
             images_b64,
             prompt_progress_callback,
             generate_args,
+            max_image_size_tuple,
             speculative_decoding_toggle,
         )
     except StopPromptProcessing:
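Taken together with the demo.py change above, callers pass a single integer and the engine squares it internally. A minimal caller sketch, assuming a loaded model kit and tokenized prompt (the loading and tokenization helpers are outside this diff) and assuming generation results expose a `.text` field as in demo.py:

```python
from mlx_engine.generate import create_generator


def generate_with_capped_images(model_kit, prompt_tokens, images_base64):
    # max_image_size is an int here; create_generator converts it to
    # (1024, 1024) before prompt processing. Passing None skips resizing.
    for result in create_generator(
        model_kit,
        prompt_tokens,
        images_b64=images_base64,
        max_image_size=1024,
        max_tokens=256,
    ):
        print(result.text, end="", flush=True)
```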
3 changes: 2 additions & 1 deletion mlx_engine/model_kit/model_kit.py
@@ -143,6 +143,7 @@ def process_prompt(
         images_b64: Optional[List[str]],
         prompt_progress_callback: Optional[Callable[[float], bool]],
         generate_args: dict,
+        max_image_size: tuple[int, int] | None,
         speculative_decoding_toggle: Optional[bool] = None,
     ) -> Tuple[mx.array, Optional[mx.array]]:
         ### TEXT-ONLY PROCESS_PROMPT ###
@@ -170,7 +171,7 @@
             )
             self._cross_prompt_cache_active = False
             input_ids, embeddings = self.vision_add_on.compute_embeddings(
-                self.model, prompt_tokens, images_b64
+                self.model, prompt_tokens, images_b64, max_size=max_image_size
             )
             return input_ids, embeddings

7 changes: 7 additions & 0 deletions mlx_engine/model_kit/vision_add_ons/base.py
@@ -21,7 +21,14 @@ def compute_embeddings(
         text_model: nn.Module,
         prompt_tokens: mx.array,
         images_b64: list[str],
+        max_size: tuple[int, int] | None,
     ) -> tuple[mx.array, mx.array]:
         """
         Returns input ids and input embeddings for the language model after text/image merging of the prompt.
+
+        Args:
+            text_model: Text model for embedding tokens
+            prompt_tokens: Input prompt tokens
+            images_b64: List of base64-encoded images
+            max_size: Maximum image size as (width, height) tuple. If None, no resizing.
         """
2 changes: 2 additions & 0 deletions mlx_engine/model_kit/vision_add_ons/gemma3.py
@@ -49,6 +49,7 @@ def compute_embeddings(
         text_model: nn.Module,
         prompt_tokens: mx.array,
         images_b64: list[str],
+        max_size: tuple[int, int] | None,
     ) -> tuple[mx.array, mx.array]:
         """Compute input_ids and embeddings for text with images."""
         input_ids, pixel_values, attention_mask, other_model_inputs = (
@@ -57,6 +58,7 @@
                 images_b64=images_b64,
                 processor=self.processor,
                 config=self.config,
+                max_size=max_size,
             )
         )
         input_embeddings = text_model.language_model.model.embed_tokens(input_ids)
2 changes: 2 additions & 0 deletions mlx_engine/model_kit/vision_add_ons/gemma3n.py
@@ -94,6 +94,7 @@ def compute_embeddings(
         text_model: nn.Module,
         prompt_tokens: mx.array,
         images_b64: list[str],
+        max_size: tuple[int, int] | None,
     ) -> tuple[mx.array, mx.array]:
         """Compute input_ids and embeddings for text with images."""
         input_ids, pixel_values, attention_mask, other_model_inputs = (
@@ -102,6 +103,7 @@
                 images_b64=images_b64,
                 processor=self.processor,
                 config=self.config,
+                max_size=max_size,
             )
         )
         assert input_ids is not None
2 changes: 2 additions & 0 deletions mlx_engine/model_kit/vision_add_ons/lfm2_vl.py
@@ -57,6 +57,7 @@ def compute_embeddings(
         text_model: nn.Module,
         prompt_tokens: mx.array,
         images_b64: list[str],
+        max_size: tuple[int, int] | None,
     ) -> tuple[mx.array, mx.array]:
         """
         Compute embeddings for text with images.
@@ -71,6 +72,7 @@
                 images_b64=images_b64,
                 processor=self.processor,
                 config=self.config,
+                max_size=max_size,
             )
         )

2 changes: 2 additions & 0 deletions mlx_engine/model_kit/vision_add_ons/mistral3.py
@@ -47,6 +47,7 @@ def compute_embeddings(
         text_model: nn.Module,
         prompt_tokens: mx.array,
         images_b64: list[str],
+        max_size: tuple[int, int] | None,
     ) -> tuple[mx.array, mx.array]:
         """
         Compute embeddings for text with images.
@@ -61,6 +62,7 @@
                 images_b64=images_b64,
                 processor=self.processor,
                 config=self.config,
+                max_size=max_size,
             )
         )

2 changes: 2 additions & 0 deletions mlx_engine/model_kit/vision_add_ons/pixtral.py
@@ -51,6 +51,7 @@ def compute_embeddings(
         text_model: nn.Module,
         prompt_tokens: mx.array,
         images_b64: list[str],
+        max_size: tuple[int, int] | None,
     ) -> tuple[mx.array, mx.array]:
         """Compute input_ids and embeddings for text with images."""
         input_ids, pixel_values, attention_mask, other_model_inputs = (
@@ -59,6 +60,7 @@
                 images_b64=images_b64,
                 processor=self.processor,
                 config=self.config,
+                max_size=max_size,
             )
         )
         input_embeddings = text_model.language_model.model.embed_tokens(input_ids)
@@ -21,10 +21,18 @@ def common_process_prompt_with_images(
     images_b64: List[str],
     processor: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
     config, # expected to be a ModelConfig object as defined by mlx-vlm. Can vary by model
+    max_size: tuple[int, int] | None,
 ) -> ProcessedImagePrompt:
     """
     Common prompt processing used by mlx-vlm vision add-ons.
     Returns a named tuple with all processed inputs.
+
+    Args:
+        prompt_tokens: Input prompt tokens
+        images_b64: List of base64-encoded images
+        processor: Tokenizer/processor for the model
+        config: Model configuration object
+        max_size: Maximum image size as (width, height) tuple. If None, no resizing.
     """
     if len(images_b64) == 0:
         raise ValueError("Images must be non-empty")
@@ -37,7 +45,7 @@
         logger.info(f"Prompt dump: {prompt}\n")
 
     images = convert_to_pil(images_b64)
-    images = custom_resize(images)
+    images = custom_resize(images, max_size=max_size)
 
     if hasattr(config, "image_token_index"):
         image_token_index = config.image_token_index
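In isolation, the image-prep step this helper performs is just two calls; a sketch using the helpers from `mlx_engine.utils.image_utils`:

```python
from mlx_engine.utils.image_utils import convert_to_pil, custom_resize


def prepare_images(images_b64: list[str], max_size: tuple[int, int] | None):
    # Decode base64 strings into PIL images, then downscale any image that
    # exceeds max_size. With max_size=None no resizing happens, although
    # custom_resize may still pad (should_pad defaults to True).
    images = convert_to_pil(images_b64)
    return custom_resize(images, max_size=max_size)
```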
2 changes: 2 additions & 0 deletions mlx_engine/model_kit/vision_add_ons/qwen2_vl.py
@@ -75,6 +75,7 @@ def compute_embeddings(
         text_model: nn.Module,
         prompt_tokens: mx.array,
         images_b64: list[str],
+        max_size: tuple[int, int] | None,
     ) -> tuple[mx.array, mx.array]:
         """
         Compute input_ids and embeddings for text with images.
@@ -86,4 +87,5 @@
             prompt_tokens=prompt_tokens,
             images_b64=images_b64,
             qwen_vl_version=2,
+            max_size=max_size,
         )
2 changes: 2 additions & 0 deletions mlx_engine/model_kit/vision_add_ons/qwen3_vl.py
@@ -47,6 +47,7 @@ def compute_embeddings(
         text_model: nn.Module,
         prompt_tokens: mx.array,
         images_b64: list[str],
+        max_size: tuple[int, int] | None,
     ) -> tuple[mx.array, mx.array]:
         """
         Compute input_ids and embeddings for text with images.
@@ -58,4 +59,5 @@
             prompt_tokens=prompt_tokens,
             images_b64=images_b64,
             qwen_vl_version=3,
+            max_size=max_size,
         )
2 changes: 2 additions & 0 deletions mlx_engine/model_kit/vision_add_ons/qwen3_vl_moe.py
@@ -47,6 +47,7 @@ def compute_embeddings(
         text_model: nn.Module,
         prompt_tokens: mx.array,
         images_b64: list[str],
+        max_size: tuple[int, int] | None,
     ) -> tuple[mx.array, mx.array]:
         """
         Compute input_ids and embeddings for text with images.
@@ -58,4 +59,5 @@
             prompt_tokens=prompt_tokens,
             images_b64=images_b64,
             qwen_vl_version=3,
+            max_size=max_size,
         )
4 changes: 3 additions & 1 deletion mlx_engine/model_kit/vision_add_ons/qwen_vl_utils.py
@@ -13,6 +13,7 @@ def compute_qwen_vl_embeddings(
     prompt_tokens: mx.array,
     images_b64: list[str],
     qwen_vl_version: int,
+    max_size: tuple[int, int] | None,
 ) -> tuple[mx.array, mx.array]:
     """
     Compute input_ids and embeddings for Qwen2-VL, Qwen2.5-VL, and Qwen3-VL models.
@@ -23,14 +24,15 @@
         prompt_tokens: Input prompt tokens
         images_b64: List of base64-encoded images
         qwen_vl_version: Version number (2 for Qwen2/2.5-VL, 3 for Qwen3-VL)
+        max_size: Maximum image size as (width, height) tuple. If None, no resizing.
 
     Returns:
         Tuple of (input_ids, final_embeddings) with batch dimension removed
     """
 
     # Convert and resize images
     images = convert_to_pil(images_b64)
-    images = custom_resize(images, should_pad=False)
+    images = custom_resize(images, max_size=max_size, should_pad=False)
 
     # Build prompt text
     tokens = (
8 changes: 5 additions & 3 deletions mlx_engine/utils/image_utils.py
@@ -15,7 +15,7 @@ def convert_to_pil(images_b64: List[str]) -> List[PIL.Image.Image]:
     ]
 
 
-def custom_resize(pil_images, max_size=(1000, 1000), should_pad=True):
+def custom_resize(pil_images, max_size=None, should_pad=True):
     """
     Resize and optionally pad a list of PIL images.
 
@@ -26,7 +26,7 @@ def custom_resize(pil_images, max_size=(1000, 1000), should_pad=True):
     Args:
         pil_images (list): A list of PIL Image objects to be processed.
         max_size (tuple): Maximum allowed dimensions (width, height) for the images.
-            Defaults to (1000, 1000).
+            If None, no resizing is performed.
         should_pad (bool): Whether to pad the images to the same size.
             Defaults to True.
 
@@ -49,7 +49,9 @@
f"Image {i + 1}: Original size {original_size}",
)

if img.width > max_size[0] or img.height > max_size[1]:
if max_size is not None and (
img.width > max_size[0] or img.height > max_size[1]
):
if img.width > img.height:
new_width = max_size[0]
new_height = int(new_width / aspect_ratio)
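The visible hunk shows only the landscape branch; here is a self-contained sketch of the whole per-image rule, where the portrait branch is inferred from the truncated diff and is an assumption:

```python
import PIL.Image


def resize_if_needed(
    img: PIL.Image.Image, max_size: tuple[int, int] | None
) -> PIL.Image.Image:
    # New default behavior: no cap means the image passes through untouched.
    if max_size is None:
        return img
    if img.width <= max_size[0] and img.height <= max_size[1]:
        return img  # already within bounds
    aspect_ratio = img.width / img.height
    if img.width > img.height:
        # Landscape: clamp width, derive height (this branch is in the diff).
        new_width = max_size[0]
        new_height = int(new_width / aspect_ratio)
    else:
        # Portrait/square: assumed mirror of the visible branch.
        new_height = max_size[1]
        new_width = int(new_height * aspect_ratio)
    return img.resize((new_width, new_height))
```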
3 changes: 2 additions & 1 deletion mlx_engine/vision_model_kit/vision_model_kit.py
@@ -101,6 +101,7 @@ def process_prompt(
         images_b64: Optional[List[str]],
         prompt_progress_callback,
         generate_args,
+        max_image_size: tuple[int, int] | None,
         speculative_decoding_toggle: Optional[bool] = None,
     ) -> Tuple[mx.array, Optional[mx.array]]:
         """
@@ -115,7 +116,7 @@
         self._reset_for_prediction()
 
         self.model.process_prompt_with_images(
-            images_b64, prompt_tokens, self.processor, self.detokenizer
+            images_b64, prompt_tokens, self.processor, self.detokenizer, max_image_size
         )
         self.has_processed_prompt = True
 
2 changes: 2 additions & 0 deletions mlx_engine/vision_model_kit/vision_model_wrapper.py
@@ -157,6 +157,7 @@ def process_prompt_with_images(
         prompt_tokens: mx.array,
         processor,
         detokenizer,
+        max_image_size: tuple[int, int] | None,
     ):
         """
         This method generates the input_ids, pixel_values, and mask for the vision model
@@ -195,6 +196,7 @@
             images_b64=images_b64,
             processor=processor,
             config=self.vision_model.config,
+            max_size=max_image_size,
         )
 
         # Set class attributes from the processed result
6 changes: 6 additions & 0 deletions tests/test_vision_models.py
@@ -12,6 +12,9 @@
 from textwrap import dedent
 
 
+MAX_IMAGE_SIZE = 1024
+
+
 class TestVisionModels(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
@@ -58,6 +61,7 @@ def toucan_test_runner(
             model_kit=model_kit,
             prompt_tokens=prompt_tokens,
             images_b64=([self.toucan_image_b64] if not text_only else None),
+            max_image_size=MAX_IMAGE_SIZE,
             seed=0,
             max_tokens=30,
             temp=0.0,
@@ -276,6 +280,7 @@ def generate_text(prompt, images_b64):
             model_kit=model_kit,
             prompt_tokens=prompt_tokens,
             images_b64=images_b64,
+            max_image_size=MAX_IMAGE_SIZE,
             seed=0,
             temp=0.0,
             max_tokens=50,
@@ -708,6 +713,7 @@ def progress_callback(progress: float) -> bool:
             model_kit=model_kit,
             prompt_tokens=prompt_tokens,
             images_b64=[self.toucan_image_b64],
+            max_image_size=MAX_IMAGE_SIZE,
             seed=0,
             temp=0.0,
             max_tokens=1,  # We only care about pre-fill in this test
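The tests above exercise the cap end to end through create_generator; the new None default could also be checked directly. A hedged unit-test sketch, assuming custom_resize returns a list of PIL images (this test is not part of the PR):

```python
import unittest

import PIL.Image

from mlx_engine.utils.image_utils import custom_resize


class TestCustomResizeCap(unittest.TestCase):
    def test_no_cap_passes_through(self):
        img = PIL.Image.new("RGB", (2048, 512))
        (out,) = custom_resize([img], max_size=None, should_pad=False)
        self.assertEqual(out.size, (2048, 512))

    def test_cap_preserves_aspect_ratio(self):
        # 2048x512 has a 4:1 aspect ratio, so a (1024, 1024) cap should
        # yield 1024x256 per the resize rule in image_utils.py.
        img = PIL.Image.new("RGB", (2048, 512))
        (out,) = custom_resize([img], max_size=(1024, 1024), should_pad=False)
        self.assertEqual(out.size, (1024, 256))
```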