Skip to content

Commit 6685217

Browse files
committed
Add support for audio and token input
1 parent 6d05fc3 commit 6685217

File tree

9 files changed

+1510
-257
lines changed

9 files changed

+1510
-257
lines changed

extension/llm/runner/CMakeLists.txt

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -104,7 +104,12 @@ if(EXECUTORCH_BUILD_PYBIND)
104104
CXX_VISIBILITY_PRESET "hidden"
105105
INTERPROCEDURAL_OPTIMIZATION TRUE
106106
)
107-
107+
if(APPLE)
108+
set(RPATH "@loader_path/../../pybindings")
109+
else()
110+
set(RPATH "$ORIGIN/../../pybindings")
111+
endif()
112+
set_target_properties(_llm_runner PROPERTIES INSTALL_RPATH ${RPATH})
108113
# Add include directories
109114
target_include_directories(
110115
_llm_runner PRIVATE ${_common_include_directories} ${TORCH_INCLUDE_DIRS}

extension/llm/runner/__init__.py

Lines changed: 7 additions & 237 deletions
Original file line number | Diff line number | Diff line change
@@ -28,10 +28,13 @@
2828
from executorch.extension.llm.runner._llm_runner import ( # noqa: F401
2929
GenerationConfig,
3030
Image,
31+
make_audio_input,
3132
make_image_input,
33+
make_raw_audio_input,
3234
make_text_input,
35+
make_token_input,
3336
MultimodalInput,
34-
MultimodalRunner as _MultimodalRunnerCpp,
37+
MultimodalRunner,
3538
Stats,
3639
)
3740
except ImportError:
@@ -40,242 +43,6 @@
4043
)
4144

4245

43-
# Define the high-level Python wrapper for MultimodalRunner
44-
class MultimodalRunner:
45-
"""
46-
High-level Python wrapper for the ExecuTorch MultimodalRunner.
47-
48-
This class provides a convenient interface for running multimodal language models
49-
that can process text, images, and other modalities to generate text output.
50-
51-
Args:
52-
model_path: Path to the ExecuTorch model file (.pte)
53-
tokenizer_path: Path to the tokenizer file
54-
temperature: Default temperature for text generation (default: 0.8)
55-
device: Device to run on (currently only 'cpu' is supported)
56-
57-
Example:
58-
>>> runner = MultimodalRunner("model.pte", "tokenizer.bin")
59-
>>> inputs = [
60-
... runner.create_text_input("Describe this image:"),
61-
... runner.create_image_input("image.jpg")
62-
... ]
63-
>>> response = runner.generate_text(inputs, max_new_tokens=100)
64-
>>> print(response)
65-
"""
66-
67-
def __init__(
68-
self,
69-
model_path: Union[str, Path],
70-
tokenizer_path: Union[str, Path],
71-
temperature: float = 0.8,
72-
device: str = "cpu",
73-
):
74-
"""Initialize the MultimodalRunner."""
75-
if device != "cpu":
76-
raise ValueError(
77-
f"Currently only 'cpu' device is supported, got '{device}'"
78-
)
79-
80-
# Convert paths to strings
81-
model_path = str(Path(model_path).resolve())
82-
tokenizer_path = str(Path(tokenizer_path).resolve())
83-
84-
# Validate paths exist
85-
if not Path(model_path).exists():
86-
raise FileNotFoundError(f"Model file not found: {model_path}")
87-
if not Path(tokenizer_path).exists():
88-
raise FileNotFoundError(f"Tokenizer file not found: {tokenizer_path}")
89-
90-
# Initialize the C++ runner
91-
self._runner = _MultimodalRunnerCpp(model_path, tokenizer_path, temperature)
92-
self._model_path = model_path
93-
self._tokenizer_path = tokenizer_path
94-
self._default_temperature = temperature
95-
96-
def create_text_input(self, text: str):
97-
"""
98-
Create a text input for multimodal processing.
99-
100-
Args:
101-
text: The input text string
102-
103-
Returns:
104-
A MultimodalInput object containing the text
105-
"""
106-
return make_text_input(text)
107-
108-
def create_image_input( # noqa: C901
109-
self, image: Union[str, Path, np.ndarray, "PILImage.Image"]
110-
):
111-
"""
112-
Create an image input for multimodal processing.
113-
114-
Args:
115-
image: Can be:
116-
- Path to an image file (str or Path)
117-
- NumPy array with shape (H, W, C) where C is 3 (RGB) or 4 (RGBA)
118-
- PIL Image object
119-
120-
Returns:
121-
A MultimodalInput object containing the image
122-
123-
Raises:
124-
ValueError: If the image format is not supported
125-
FileNotFoundError: If the image file doesn't exist
126-
"""
127-
if isinstance(image, (str, Path)):
128-
# Load image from file
129-
image_path = Path(image)
130-
if not image_path.exists():
131-
raise FileNotFoundError(f"Image file not found: {image_path}")
132-
133-
if HAS_PIL:
134-
pil_image = PILImage.open(image_path)
135-
# Convert to RGB if necessary
136-
if pil_image.mode != "RGB":
137-
pil_image = pil_image.convert("RGB")
138-
image = np.array(pil_image, dtype=np.uint8)
139-
else:
140-
# Try to use cv2 if available
141-
try:
142-
import cv2
143-
144-
image = cv2.imread(str(image_path))
145-
if image is None:
146-
raise ValueError(f"Failed to load image: {image_path}")
147-
# Convert BGR to RGB
148-
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
149-
except ImportError:
150-
raise ImportError(
151-
"Either PIL or OpenCV is required to load images from files. "
152-
"Install with: pip install pillow or pip install opencv-python"
153-
)
154-
155-
elif HAS_PIL and isinstance(image, PILImage.Image):
156-
# Convert PIL Image to numpy array
157-
if image.mode != "RGB":
158-
image = image.convert("RGB")
159-
image = np.array(image, dtype=np.uint8)
160-
161-
elif isinstance(image, np.ndarray):
162-
# Validate numpy array
163-
if image.ndim != 3:
164-
raise ValueError(
165-
f"Image array must be 3-dimensional (H, W, C), got shape {image.shape}"
166-
)
167-
if image.shape[2] not in [3, 4]:
168-
raise ValueError(
169-
f"Image must have 3 (RGB) or 4 (RGBA) channels, got {image.shape[2]}"
170-
)
171-
if image.dtype != np.uint8:
172-
# Convert to uint8 if necessary
173-
if image.max() <= 1.0:
174-
# Assume normalized [0, 1] range
175-
image = (image * 255).astype(np.uint8)
176-
else:
177-
image = image.astype(np.uint8)
178-
else:
179-
raise ValueError(f"Unsupported image type: {type(image)}")
180-
181-
return make_image_input(image)
182-
183-
def generate(
184-
self,
185-
inputs: List[Any],
186-
config: Optional[GenerationConfig] = None,
187-
token_callback: Optional[Callable[[str], None]] = None,
188-
stats_callback: Optional[Callable[[Any], None]] = None,
189-
):
190-
"""
191-
Generate text from multimodal inputs with streaming callbacks.
192-
193-
Args:
194-
inputs: List of multimodal inputs (text, images, etc.)
195-
config: Generation configuration (uses defaults if None)
196-
token_callback: Function called for each generated token
197-
stats_callback: Function called with generation statistics
198-
"""
199-
if config is None:
200-
config = GenerationConfig()
201-
config.temperature = self._default_temperature
202-
203-
self._runner.generate(inputs, config, token_callback, stats_callback)
204-
205-
def generate_text(
206-
self,
207-
inputs: List[Any],
208-
config: Optional[GenerationConfig] = None,
209-
max_new_tokens: Optional[int] = None,
210-
temperature: Optional[float] = None,
211-
top_p: Optional[float] = None,
212-
**kwargs,
213-
) -> str:
214-
"""
215-
Generate text from multimodal inputs and return the complete result.
216-
217-
Args:
218-
inputs: List of multimodal inputs (text, images, etc.)
219-
config: Generation configuration (overrides other parameters if provided)
220-
max_new_tokens: Maximum number of tokens to generate
221-
temperature: Sampling temperature (0.0 to 1.0)
222-
top_p: Top-p sampling parameter
223-
**kwargs: Additional generation parameters
224-
225-
Returns:
226-
The generated text as a string
227-
"""
228-
if config is None:
229-
config = GenerationConfig()
230-
config.temperature = temperature or self._default_temperature
231-
if max_new_tokens is not None:
232-
config.max_new_tokens = max_new_tokens
233-
if top_p is not None:
234-
config.top_p = top_p
235-
236-
# Set any additional parameters
237-
for key, value in kwargs.items():
238-
if hasattr(config, key):
239-
setattr(config, key, value)
240-
241-
return self._runner.generate_text(inputs, config) # type: ignore[attr-defined]
242-
243-
def stop(self):
244-
"""Stop the current generation process."""
245-
self._runner.stop()
246-
247-
@property
248-
def vocab_size(self) -> int:
249-
"""Get the vocabulary size of the model."""
250-
return self._runner.get_vocab_size()
251-
252-
@property
253-
def model_path(self) -> str:
254-
"""Get the path to the loaded model."""
255-
return self._model_path
256-
257-
@property
258-
def tokenizer_path(self) -> str:
259-
"""Get the path to the loaded tokenizer."""
260-
return self._tokenizer_path
261-
262-
def __repr__(self) -> str:
263-
return (
264-
f"MultimodalRunner(model='{Path(self._model_path).name}', "
265-
f"tokenizer='{Path(self._tokenizer_path).name}', "
266-
f"vocab_size={self.vocab_size})"
267-
)
268-
269-
def __enter__(self):
270-
"""Context manager entry."""
271-
return self
272-
273-
def __exit__(self, exc_type, exc_val, exc_tb):
274-
"""Context manager exit - ensures cleanup."""
275-
self.stop()
276-
return False
277-
278-
27946
# Import utility functions
28047
from .utils import create_generation_config, load_image_from_file, preprocess_image
28148

@@ -285,7 +52,10 @@ def __exit__(self, exc_type, exc_val, exc_tb):
28552
"Stats",
28653
"Image",
28754
"MultimodalInput",
55+
"make_audio_input",
56+
"make_raw_audio_input",
28857
"make_text_input",
58+
"make_token_input",
28959
"make_image_input",
29060
"load_image_from_file",
29161
"preprocess_image",

0 commit comments

Comments
 (0)