Skip to content

Commit b6d540d

Browse files
committed
More fixes?
1 parent 7ba7d88 commit b6d540d

File tree

4 files changed

+67
-45
lines changed

4 files changed

+67
-45
lines changed

.ci/scripts/test_huggingface_optimum_model.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -230,18 +230,13 @@ def test_llm_with_image_modality(
230230
return_tensors="pt",
231231
)
232232

233-
import torch
234-
235-
from executorch.extension.llm.runner import (
236-
GenerationConfig,
237-
make_image_input,
238-
make_token_input,
239-
MultimodalRunner,
240-
)
233+
from executorch.extension.llm.runner import GenerationConfig, MultimodalRunner
241234

242235
runner = MultimodalRunner(f"{model_dir}/model.pte", f"{model_dir}/tokenizer.model")
243236
generated_text = runner.generate_text_hf(
244-
inputs, GenerationConfig(max_new_tokens=128, temperature=0, echo=False), processor.image_token_id
237+
inputs,
238+
GenerationConfig(max_new_tokens=128, temperature=0, echo=False),
239+
processor.image_token_id,
245240
)
246241
print(f"\nGenerated text:\n\t{generated_text}")
247242
# Free memory before loading eager for quality check

extension/llm/runner/__init__.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,10 @@
3232

3333

3434
import logging
35-
from typing import Any, Callable, Dict, List, Optional, Union
35+
from typing import Callable, List, Optional, Union
3636

3737
import torch
38+
from transformers.feature_extraction_utils import BatchFeature
3839

3940

4041
def _find_image_token_runs(
@@ -65,13 +66,13 @@ def _find_image_token_runs(
6566

6667

6768
def _hf_to_multimodal_inputs( # noqa: C901
68-
inputs: Dict[str, Any], image_token_id: Optional[int] = None
69+
inputs: BatchFeature, image_token_id: Optional[int] = None
6970
) -> List[MultimodalInput]:
7071
"""Convert a HuggingFace AutoProcessor dict to ExecuTorch MultimodalInputs.
7172
Currently only supports 1 image inside the input.
7273
7374
Args:
74-
- inputs: A dictionary containing the input data.
75+
- inputs: A BatchFeature containing the input data.
7576
- image_token_id: The token ID for the image, if present.
7677
7778
`inputs` expected keys:
@@ -168,38 +169,50 @@ def _hf_to_multimodal_inputs( # noqa: C901
168169

169170
def generate_hf(
170171
runner: MultimodalRunner,
171-
inputs: Union[Dict[str, Any], List[MultimodalInput]],
172+
inputs: Union[BatchFeature, List[MultimodalInput]],
172173
config: GenerationConfig,
173174
image_token_id: Optional[int] = None,
174175
token_callback: Optional[Callable[[str], None]] = None,
175176
stats_callback: Optional[Callable[[Stats], None]] = None,
176177
) -> None:
177-
"""Generate using an HF dict by converting to multimodal inputs internally, or using a list of MultimodalInput."""
178-
if isinstance(inputs, dict):
178+
"""Generate using a BatchFeature by converting to multimodal inputs internally, or using a list of MultimodalInput."""
179+
if isinstance(inputs, BatchFeature):
179180
logging.info(
180-
"Input is a dict, assuming it's coming from HF AutoProcessor.apply_chat_template(). Converting to multimodal inputs."
181+
"Input is a BatchFeature, assuming it's coming from HF AutoProcessor.apply_chat_template(). Converting to multimodal inputs."
181182
)
182183
converted = _hf_to_multimodal_inputs(inputs, image_token_id=image_token_id)
183-
else:
184+
elif isinstance(inputs, list) and all(
185+
isinstance(i, MultimodalInput) for i in inputs
186+
):
184187
converted = inputs
188+
else:
189+
raise RuntimeError(
190+
"inputs must be either a BatchFeature (from HF AutoProcessor) or a list of MultimodalInput"
191+
)
185192

186193
runner.generate(converted, config, token_callback, stats_callback)
187194

188195

189196
def generate_text_hf(
190197
runner: MultimodalRunner,
191-
inputs: Union[Dict[str, Any], List[MultimodalInput]],
198+
inputs: Union[BatchFeature, List[MultimodalInput]],
192199
config: GenerationConfig,
193200
image_token_id: Optional[int] = None,
194201
) -> str:
195-
"""Generate using an HF dict by converting to multimodal inputs internally, or using a list of MultimodalInput."""
196-
if isinstance(inputs, dict):
202+
"""Generate using a BatchFeature by converting to multimodal inputs internally, or using a list of MultimodalInput."""
203+
if isinstance(inputs, BatchFeature):
197204
logging.info(
198-
"Input is a dict, assuming it's coming from HF AutoProcessor.apply_chat_template(). Converting to multimodal inputs."
205+
"Input is a BatchFeature, assuming it's coming from HF AutoProcessor.apply_chat_template(). Converting to multimodal inputs."
199206
)
200207
converted = _hf_to_multimodal_inputs(inputs, image_token_id=image_token_id)
201-
else:
208+
elif isinstance(inputs, list) and all(
209+
isinstance(i, MultimodalInput) for i in inputs
210+
):
202211
converted = inputs
212+
else:
213+
raise RuntimeError(
214+
"inputs must be either a BatchFeature (from HF AutoProcessor) or a list of MultimodalInput"
215+
)
203216

204217
return runner.generate_text(converted, config)
205218

extension/llm/runner/_llm_runner.pyi

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ Type stubs for _llm_runner module.
44
This file provides type annotations for the ExecuTorch LLM Runner Python bindings.
55
"""
66

7-
from typing import Callable, List, Optional, Union
7+
from typing import Callable, List, Optional, Union, overload
88

99
import numpy as np
1010
import torch
@@ -134,14 +134,17 @@ class Stats:
134134
class Image:
135135
"""Container for image data."""
136136

137+
@overload
137138
def __init__(self) -> None:
138139
"""Initialize an empty Image."""
139140
...
140141

142+
@overload
141143
def __init__(self, data: List[int], width: int, height: int, channels: int) -> None:
142144
"""Initialize an Image with uint8 data."""
143145
...
144146

147+
@overload
145148
def __init__(
146149
self, data: List[float], width: int, height: int, channels: int
147150
) -> None:
@@ -198,10 +201,12 @@ class Audio:
198201
n_frames: int
199202
"""Number of time frames."""
200203

204+
@overload
201205
def __init__(self) -> None:
202206
"""Initialize an empty Audio."""
203207
...
204208

209+
@overload
205210
def __init__(
206211
self, data: List[int], batch_size: int, n_bins: int, n_frames: int
207212
) -> None:
@@ -225,10 +230,12 @@ class RawAudio:
225230
n_samples: int
226231
"""Number of audio samples."""
227232

233+
@overload
228234
def __init__(self) -> None:
229235
"""Initialize an empty RawAudio."""
230236
...
231237

238+
@overload
232239
def __init__(
233240
self, data: List[int], batch_size: int, n_channels: int, n_samples: int
234241
) -> None:
@@ -240,6 +247,7 @@ class RawAudio:
240247
class MultimodalInput:
241248
"""Container for multimodal input data (text, image, audio, etc.)."""
242249

250+
@overload
243251
def __init__(self, text: str) -> None:
244252
"""
245253
Create a MultimodalInput with text.
@@ -249,6 +257,7 @@ class MultimodalInput:
249257
"""
250258
...
251259

260+
@overload
252261
def __init__(self, image: Image) -> None:
253262
"""
254263
Create a MultimodalInput with an image.
@@ -258,6 +267,7 @@ class MultimodalInput:
258267
"""
259268
...
260269

270+
@overload
261271
def __init__(self, audio: Audio) -> None:
262272
"""
263273
Create a MultimodalInput with preprocessed audio.
@@ -267,6 +277,7 @@ class MultimodalInput:
267277
"""
268278
...
269279

280+
@overload
270281
def __init__(self, raw_audio: RawAudio) -> None:
271282
"""
272283
Create a MultimodalInput with raw audio.
@@ -347,6 +358,7 @@ class MultimodalRunner:
347358
RuntimeError: If initialization fails
348359
"""
349360
...
361+
350362
def generate(
351363
self,
352364
inputs: List[MultimodalInput],
@@ -366,7 +378,7 @@ class MultimodalRunner:
366378
Raises:
367379
RuntimeError: If generation fails
368380
"""
369-
...
381+
...
370382

371383
def generate_hf(
372384
self,
@@ -424,9 +436,11 @@ class MultimodalRunner:
424436
Raises:
425437
RuntimeError: If generation fails
426438
"""
427-
...
439+
...
428440

429-
def generate_text_hf(self, inputs: dict, config: GenerationConfig, image_token_id) -> str:
441+
def generate_text_hf(
442+
self, inputs: dict, config: GenerationConfig, image_token_id
443+
) -> str:
430444
"""
431445
Generate text directly from a HuggingFace processor dict and return as string.
432446

extension/llm/runner/pybindings.cpp

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -173,15 +173,15 @@ PYBIND11_MODULE(_llm_runner, m) {
173173
float temperature,
174174
int32_t num_bos,
175175
int32_t num_eos) {
176-
GenerationConfig cfg;
177-
cfg.echo = echo;
178-
cfg.max_new_tokens = max_new_tokens;
179-
cfg.warming = warming;
180-
cfg.seq_len = seq_len;
181-
cfg.temperature = temperature;
182-
cfg.num_bos = num_bos;
183-
cfg.num_eos = num_eos;
184-
return cfg;
176+
GenerationConfig cfg;
177+
cfg.echo = echo;
178+
cfg.max_new_tokens = max_new_tokens;
179+
cfg.warming = warming;
180+
cfg.seq_len = seq_len;
181+
cfg.temperature = temperature;
182+
cfg.num_bos = num_bos;
183+
cfg.num_eos = num_eos;
184+
return cfg;
185185
}),
186186
py::arg("echo") = true,
187187
py::arg("max_new_tokens") = -1,
@@ -204,12 +204,12 @@ PYBIND11_MODULE(_llm_runner, m) {
204204
py::arg("num_prompt_tokens"),
205205
"Resolve the maximum number of new tokens to generate based on constraints")
206206
.def("__repr__", [](const GenerationConfig& config) {
207-
return "<GenerationConfig max_new_tokens=" +
208-
std::to_string(config.max_new_tokens) +
209-
" seq_len=" + std::to_string(config.seq_len) +
210-
" temperature=" + std::to_string(config.temperature) +
211-
" echo=" + (config.echo ? "True" : "False") +
212-
" warming=" + (config.warming ? "True" : "False") + ">";
207+
return "<GenerationConfig max_new_tokens=" +
208+
std::to_string(config.max_new_tokens) +
209+
" seq_len=" + std::to_string(config.seq_len) +
210+
" temperature=" + std::to_string(config.temperature) +
211+
" echo=" + (config.echo ? "True" : "False") +
212+
" warming=" + (config.warming ? "True" : "False") + ">";
213213
});
214214

215215
// Bind Stats
@@ -365,10 +365,10 @@ PYBIND11_MODULE(_llm_runner, m) {
365365
py::init<const std::string&>(),
366366
py::arg("text"),
367367
"Create a MultimodalInput with text")
368-
.def(
369-
py::init<const std::vector<uint64_t>&>(),
370-
py::arg("tokens"),
371-
"Create a MultimodalInput with pre-tokenized tokens (List[int])")
368+
.def(
369+
py::init<const std::vector<uint64_t>&>(),
370+
py::arg("tokens"),
371+
"Create a MultimodalInput with pre-tokenized tokens (List[int])")
372372
.def(
373373
py::init<const std::vector<uint64_t>&>(),
374374
py::arg("tokens"),

0 commit comments

Comments
 (0)