
Commit a602201

Fix tests
1 parent c4868e0 commit a602201

4 files changed: +235 -36 lines changed


CMakeLists.txt

Lines changed: 4 additions & 4 deletions
@@ -650,10 +650,6 @@ if(EXECUTORCH_BUILD_EXTENSION_LLM)
   list(APPEND _executorch_extensions tokenizers)
 endif()
 
-if(EXECUTORCH_BUILD_EXTENSION_LLM_APPLE)
-  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/llm/apple)
-endif()
-
 if(EXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/runner_util)
   install(
@@ -904,6 +900,10 @@ if(EXECUTORCH_BUILD_EXTENSION_LLM_RUNNER)
   list(APPEND _executorch_extensions extension_llm_runner)
 endif()
 
+if(EXECUTORCH_BUILD_EXTENSION_LLM_APPLE)
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/llm/apple)
+endif()
+
 if(EXECUTORCH_BUILD_KERNELS_LLM)
   # TODO: move all custom kernels to ${CMAKE_CURRENT_SOURCE_DIR}/kernels/custom
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/llm/custom_ops)

extension/llm/runner/__init__.py

Lines changed: 177 additions & 0 deletions
@@ -31,6 +31,183 @@
 )
 
 
+import logging
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import torch
+
+
+def _find_image_token_runs(
+    input_ids: torch.Tensor, image_token_id: Optional[int]
+) -> List[tuple[int, int, int]]:
+    """Return contiguous runs (start, end, length) of image_token_id in input_ids.
+
+    input_ids must be a 1D torch.Tensor. If image_token_id is None, returns an empty list.
+    """
+    if image_token_id is None:
+        return []
+
+    ids_list = input_ids.tolist()
+    runs: List[tuple[int, int, int]] = []
+    i = 0
+    L = len(ids_list)
+    while i < L:
+        if ids_list[i] == image_token_id:
+            j = i
+            while j < L and ids_list[j] == image_token_id:
+                j += 1
+            runs.append((i, j - 1, j - i))
+            i = j
+        else:
+            i += 1
+
+    return runs
+
+
+def _hf_to_multimodal_inputs(  # noqa: C901
+    inputs: Dict[str, Any], image_token_id: Optional[int] = None
+) -> List[MultimodalInput]:
+    """Convert a HuggingFace AutoProcessor dict to ExecuTorch MultimodalInputs.
+    Currently only supports 1 image inside the input.
+
+    Args:
+        - inputs: A dictionary containing the input data.
+        - image_token_id: The token ID for the image, if present.
+
+    `inputs` expected keys:
+        - 'input_ids': torch.Tensor of shape (L,) or (1, L)
+        - Optional 'pixel_values': torch.Tensor; if present, must also provide
+          'image_token_id' (or alias 'image_token_index') and there must be
+          exactly one image token occurrence in input_ids.
+
+    Raises:
+        RuntimeError: missing keys, invalid shapes/dtypes, or unsupported cases.
+    """
+    if "input_ids" not in inputs:
+        raise RuntimeError("HF inputs dict must contain 'input_ids' (torch.Tensor)")
+
+    input_ids = inputs["input_ids"]
+    if not isinstance(input_ids, torch.Tensor):
+        raise RuntimeError("'input_ids' must be a torch.Tensor")
+
+    if input_ids.dim() == 2:
+        if input_ids.size(0) != 1:
+            raise RuntimeError(
+                "Expected 'input_ids' with batch size 1 when 2D (shape (1, L))"
+            )
+        input_ids = input_ids.squeeze(0)
+    if input_ids.dim() != 1:
+        raise RuntimeError("'input_ids' must be 1D (L) or 2D with batch size 1")
+
+    has_pixel_values = "pixel_values" in inputs
+
+    # If pixel_values is in the dict, require image_token_id
+    if has_pixel_values and image_token_id is None:
+        raise RuntimeError("'pixel_values' provided but missing 'image_token_id'")
+
+    # If there are image token ids but no pixel_values, it's an error
+    if (
+        image_token_id is not None
+        and (input_ids == image_token_id).any().item()
+        and not has_pixel_values
+    ):
+        raise RuntimeError(
+            "Found image token(s) in input_ids but 'pixel_values' not provided"
+        )
+
+    # No images: return a single tokens input
+    if not has_pixel_values:
+        return [make_token_input(input_ids.to(torch.long).tolist())]
+
+    # Determine number of images from pixel_values shape
+    pv = inputs["pixel_values"]
+    if not isinstance(pv, torch.Tensor):
+        raise RuntimeError(
+            "'pixel_values' must be a torch.Tensor, run with `return_tensors='pt'` in HF processor"
+        )
+    if pv.dim() == 4:
+        num_images = int(pv.size(0))
+    elif pv.dim() == 3:
+        num_images = 1
+    else:
+        raise RuntimeError(
+            f"'pixel_values' must be 3D (C,H,W) or 4D (N,C,H,W)/(N,H,W,C), got shape {pv.shape}"
+        )
+
+    # Only support batch size 1 for now
+    if num_images != 1:
+        raise RuntimeError("Only 1 image is supported for now")
+    # Find contiguous runs of image_token_id in input_ids
+    runs = _find_image_token_runs(input_ids, image_token_id)
+
+    if len(runs) == 0:
+        raise RuntimeError(
+            "'pixel_values' provided but no occurrence of 'image_token_id' in input_ids"
+        )
+
+    # Support only one image/run for now; enforce exact match
+    if num_images != 1 or len(runs) != 1:
+        raise RuntimeError(
+            f"Mismatch between images and image token runs: images={num_images}, runs={len(runs)} (only batch=1 and a single contiguous run are supported)"
+        )
+
+    first, last, _ = runs[0]
+
+    combined: List[MultimodalInput] = []
+    if first > 0:
+        combined.append(make_token_input(input_ids[:first].to(torch.long).tolist()))
+
+    # Use the C++ checked creator for images (handles 3D/4D, CHW/HWC, uint8/float32)
+    combined.append(make_image_input(inputs["pixel_values"]))
+
+    if (last + 1) < input_ids.numel():
+        combined.append(make_token_input(input_ids[last + 1 :].to(torch.long).tolist()))
+
+    return combined
+
+
+def generate(
+    runner: MultimodalRunner,
+    inputs: Union[Dict[str, Any], List[MultimodalInput]],
+    config: GenerationConfig,
+    image_token_id: Optional[int] = None,
+    token_callback: Optional[Callable[[str], None]] = None,
+    stats_callback: Optional[Callable[[Stats], None]] = None,
+) -> None:
+    """Generate using an HF dict by converting it to multimodal inputs internally, or using a list of MultimodalInput."""
+    if isinstance(inputs, dict):
+        logging.info(
+            "Input is a dict, assuming it's coming from HF AutoProcessor.apply_chat_template(). Converting to multimodal inputs."
+        )
+        converted = _hf_to_multimodal_inputs(inputs, image_token_id=image_token_id)
+    else:
+        converted = inputs
+
+    runner.generate(converted, config, token_callback, stats_callback)
+
+
+def generate_text(
+    runner: MultimodalRunner,
+    inputs: Union[Dict[str, Any], List[MultimodalInput]],
+    config: GenerationConfig,
+    image_token_id: Optional[int] = None,
+) -> str:
+    """Generate using an HF dict by converting it to multimodal inputs internally, or using a list of MultimodalInput."""
+    if isinstance(inputs, dict):
+        logging.info(
+            "Input is a dict, assuming it's coming from HF AutoProcessor.apply_chat_template(). Converting to multimodal inputs."
+        )
+        converted = _hf_to_multimodal_inputs(inputs, image_token_id=image_token_id)
+    else:
+        converted = inputs
+
+    return runner.generate_text(converted, config)
+
+
+setattr(MultimodalRunner, "generate", generate)  # noqa: B010
+setattr(MultimodalRunner, "generate_text", generate_text)  # noqa: B010
+
+
 __all__ = [
     "GenerationConfig",
     "Image",

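For illustration, here is a minimal sketch of what the conversion helper added above is expected to produce for a prompt containing one image. The token ids (1, 15043, 29871, 2), the image token id 32000, and the zero tensor standing in for real pixel_values are made-up values; only make_token_input, make_image_input, and the run-splitting behaviour come from this commit.

import torch

from executorch.extension.llm.runner import make_image_input, make_token_input

IMAGE_TOKEN_ID = 32000  # hypothetical image token id for some model
input_ids = torch.tensor([1, 15043, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 29871, 2])

# _find_image_token_runs(input_ids, IMAGE_TOKEN_ID) would return [(2, 3, 2)]:
# one contiguous run starting at index 2, ending at index 3, of length 2.

# _hf_to_multimodal_inputs would then emit three MultimodalInputs, in order:
expected = [
    make_token_input([1, 15043]),                # tokens before the image run
    make_image_input(torch.zeros(3, 224, 224)),  # the image (float32 CHW placeholder)
    make_token_input([29871, 2]),                # tokens after the image run
]
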
extension/llm/runner/_llm_runner.pyi

Lines changed: 33 additions & 0 deletions
@@ -368,6 +368,31 @@ class MultimodalRunner:
         Raises:
             RuntimeError: If generation fails
         """
+        ...
+
+    def generate(
+        self,
+        inputs: dict,
+        config: GenerationConfig,
+        token_callback: Optional[Callable[[str], None]] = None,
+        stats_callback: Optional[Callable[[Stats], None]] = None,
+    ) -> None:
+        """
+        Generate text directly from a HuggingFace processor dict.
+
+        Expects at least 'input_ids' (torch.Tensor). If 'pixel_values' is provided,
+        an 'image_token_id' (or 'image_token_index') must also be present to locate
+        the image position(s) in input_ids.
+
+        Args:
+            inputs: HF processor outputs (e.g., from AutoProcessor.apply_chat_template)
+            config: Generation configuration
+            token_callback: Optional per-token callback
+            stats_callback: Optional stats callback
+
+        Raises:
+            RuntimeError: If required keys are missing, shapes are invalid, or generation fails
+        """
         ...
 
     def prefill(self, inputs: List[MultimodalInput]) -> None:
@@ -399,6 +424,14 @@ class MultimodalRunner:
         Raises:
             RuntimeError: If generation fails
         """
+        ...
+
+    def generate_text(self, inputs: dict, config: GenerationConfig) -> str:
+        """
+        Generate text directly from a HuggingFace processor dict and return as string.
+
+        See generate(inputs: dict, ...) for expected keys and constraints.
+        """
         ...
 
     def stop(self) -> None:
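A sketch of how the dict overload stubbed above might be called. The already-constructed runner, the default-constructed GenerationConfig, and the image token id 32000 are assumptions not shown in this commit; the generate() signature and expected dict keys come from the stub, and the image_token_id keyword comes from the Python wrapper in __init__.py, which requires it whenever 'pixel_values' is present.

import torch

from executorch.extension.llm.runner import GenerationConfig

# Assumed to exist already: `runner`, a MultimodalRunner built from an exported
# model and tokenizer; its construction is outside the scope of this commit.
hf_inputs = {
    "input_ids": torch.tensor([[1, 15043, 32000, 29871, 2]]),  # (1, L), as HF processors return
    "pixel_values": torch.zeros(1, 3, 224, 224),               # (N, C, H, W) float32
}

config = GenerationConfig()  # assumed default-constructible; fields are not shown in this diff
runner.generate(
    hf_inputs,
    config,
    image_token_id=32000,  # hypothetical; normally taken from the model's processor/tokenizer
    token_callback=lambda piece: print(piece, end="", flush=True),
)
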

extension/llm/runner/test/test_runner_pybindings.py

Lines changed: 21 additions & 32 deletions
@@ -16,7 +16,7 @@
 import tempfile
 import unittest
 
-import numpy as np
+import torch
 from executorch.extension.llm.runner import (
     GenerationConfig,
     Image,
@@ -118,25 +118,18 @@ class TestImage(unittest.TestCase):
 
     def test_creation(self):
         """Test creating an Image object."""
-        image = Image()
+        # Construct using binding constructor (uint8 data)
+        image = Image([1, 2, 3, 4], 2, 2, 1)
 
-        # Set properties
-        image.data = [1, 2, 3, 4]
-        image.width = 2
-        image.height = 2
-        image.channels = 1
-
-        self.assertEqual(image.data, [1, 2, 3, 4])
+        # Properties are read-only
+        self.assertEqual(image.uint8_data, [1, 2, 3, 4])
         self.assertEqual(image.width, 2)
         self.assertEqual(image.height, 2)
         self.assertEqual(image.channels, 1)
 
     def test_repr(self):
         """Test string representation."""
-        image = Image()
-        image.width = 640
-        image.height = 480
-        image.channels = 3
+        image = Image([0] * (480 * 640 * 3), 640, 480, 3)
 
         repr_str = repr(image)
         self.assertIn("Image", repr_str)
@@ -164,33 +157,29 @@ def test_text_input(self):
     def test_image_input(self):
         """Test creating an image MultimodalInput."""
         # Create an image
-        image = Image()
-        image.data = [255] * (100 * 100 * 3)
-        image.width = 100
-        image.height = 100
-        image.channels = 3
+        image = Image([255] * (100 * 100 * 3), 100, 100, 3)
 
         # Test direct constructor
         image_input = MultimodalInput(image)
         self.assertTrue(image_input.is_image())
         self.assertFalse(image_input.is_text())
 
-        # Test helper function with numpy array
-        img_array = np.ones((50, 60, 3), dtype=np.uint8) * 128
-        image_input2 = make_image_input(img_array)
+        # Test helper function with torch tensor (CHW)
+        img_tensor = torch.ones((3, 50, 60), dtype=torch.uint8) * 128
+        image_input2 = make_image_input(img_tensor)
         self.assertTrue(image_input2.is_image())
         self.assertFalse(image_input2.is_text())
 
     def test_invalid_image_array(self):
         """Test error handling for invalid image arrays."""
-        # Wrong dimensions
+        # Wrong dimensions (expects 3D or 4D tensor)
         with self.assertRaises(RuntimeError) as cm:
-            make_image_input(np.ones((100,), dtype=np.uint8))
+            make_image_input(torch.ones((100,), dtype=torch.uint8))
         self.assertIn("3-dimensional", str(cm.exception))
 
         # Wrong number of channels
         with self.assertRaises(RuntimeError) as cm:
-            make_image_input(np.ones((100, 100, 2), dtype=np.uint8))
+            make_image_input(torch.ones((2, 100, 100), dtype=torch.uint8))
         self.assertIn("3 (RGB) or 4 (RGBA)", str(cm.exception))
 
     def test_repr(self):
@@ -209,7 +198,7 @@ def test_repr(self):
         self.assertIn("...", repr_str2)
 
         # Image input
-        image = Image()
+        image = Image([0, 0, 0], 1, 1, 3)
         image_input = MultimodalInput(image)
         repr_str3 = repr(image_input)
         self.assertIn("type=image", repr_str3)
@@ -256,14 +245,14 @@ def test_make_text_input(self):
 
     def test_make_image_input(self):
         """Test make_image_input helper."""
-        # Create a test image array (RGB)
-        img_array = np.zeros((100, 150, 3), dtype=np.uint8)
-        img_array[:, :, 0] = 255  # Red channel
+        # Create a test image tensor (RGB, CHW)
+        img_tensor = torch.zeros((3, 100, 150), dtype=torch.uint8)
+        img_tensor[0, :, :] = 255  # Red channel
 
-        image_input = make_image_input(img_array)
+        image_input = make_image_input(img_tensor)
         self.assertTrue(image_input.is_image())
 
-        # Test with RGBA
-        img_array_rgba = np.ones((50, 50, 4), dtype=np.uint8) * 128
-        image_input_rgba = make_image_input(img_array_rgba)
+        # Test with RGBA (CHW)
+        img_tensor_rgba = torch.ones((4, 50, 50), dtype=torch.uint8) * 128
+        image_input_rgba = make_image_input(img_tensor_rgba)
         self.assertTrue(image_input_rgba.is_image())
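These tests switch image construction from numpy arrays to torch tensors in CHW layout. For code that still starts from an HWC uint8 array (for example the result of np.asarray on a PIL image, which is an illustrative assumption here, not part of this commit), a minimal migration sketch is to permute to CHW before calling make_image_input, matching what the updated tests exercise:

import numpy as np
import torch

from executorch.extension.llm.runner import make_image_input

# HWC uint8 array, e.g. as decoded from an image file; values are placeholders.
hwc = np.full((50, 60, 3), 128, dtype=np.uint8)

# Convert to the CHW torch layout used by the updated tests.
chw = torch.from_numpy(hwc).permute(2, 0, 1).contiguous()

image_input = make_image_input(chw)
assert image_input.is_image()
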
