Skip to content

Commit a45b2af

Browse files
committed
Retry fixing CI
1 parent d7c9492 commit a45b2af

File tree

3 files changed

+11
-23
lines changed

3 files changed

+11
-23
lines changed

.ci/scripts/test_huggingface_optimum_model.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -232,30 +232,16 @@ def test_llm_with_image_modality(
232232

233233
import torch
234234

235-
first_image_id_index = torch.where(inputs["input_ids"] == processor.image_token_id)[
236-
1
237-
][0].item()
238-
last_image_id_index = torch.where(inputs["input_ids"] == processor.image_token_id)[
239-
1
240-
][-1].item()
241-
242-
prompt_before_image = inputs["input_ids"][0, :first_image_id_index]
243-
prompt_after_image = inputs["input_ids"][0, last_image_id_index + 1 :]
244235
from executorch.extension.llm.runner import (
245236
GenerationConfig,
246237
make_image_input,
247238
make_token_input,
248239
MultimodalRunner,
249240
)
250241

251-
combined_inputs = [
252-
make_token_input(prompt_before_image.tolist()),
253-
make_image_input(inputs["pixel_values"]),
254-
make_token_input(prompt_after_image.tolist()),
255-
]
256242
runner = MultimodalRunner(f"{model_dir}/model.pte", f"{model_dir}/tokenizer.model")
257-
generated_text = runner.generate_text(
258-
combined_inputs, GenerationConfig(max_new_tokens=128, temperature=0, echo=False)
243+
generated_text = runner.generate_text_hf(
244+
inputs, GenerationConfig(max_new_tokens=128, temperature=0, echo=False), processor.image_token_id
259245
)
260246
print(f"\nGenerated text:\n\t{generated_text}")
261247
# Free memory before loading eager for quality check

extension/llm/runner/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def _hf_to_multimodal_inputs( # noqa: C901
166166
return combined
167167

168168

169-
def generate(
169+
def generate_hf(
170170
runner: MultimodalRunner,
171171
inputs: Union[Dict[str, Any], List[MultimodalInput]],
172172
config: GenerationConfig,
@@ -186,7 +186,7 @@ def generate(
186186
runner.generate(converted, config, token_callback, stats_callback)
187187

188188

189-
def generate_text(
189+
def generate_text_hf(
190190
runner: MultimodalRunner,
191191
inputs: Union[Dict[str, Any], List[MultimodalInput]],
192192
config: GenerationConfig,
@@ -204,8 +204,8 @@ def generate_text(
204204
return runner.generate_text(converted, config)
205205

206206

207-
setattr(MultimodalRunner, "generate", generate) # noqa B010
208-
setattr(MultimodalRunner, "generate_text", generate_text) # noqa B010
207+
setattr(MultimodalRunner, "generate_hf", generate_hf) # noqa B010
208+
setattr(MultimodalRunner, "generate_text_hf", generate_text_hf) # noqa B010
209209

210210

211211
__all__ = [

extension/llm/runner/_llm_runner.pyi

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -368,12 +368,13 @@ class MultimodalRunner:
368368
"""
369369
...
370370

371-
def generate(
371+
def generate_hf(
372372
self,
373373
inputs: dict,
374374
config: GenerationConfig,
375375
token_callback: Optional[Callable[[str], None]] = None,
376376
stats_callback: Optional[Callable[[Stats], None]] = None,
377+
image_token_id: Optional[int] = None,
377378
) -> None:
378379
"""
379380
Generate text directly from a HuggingFace processor dict.
@@ -387,6 +388,7 @@ class MultimodalRunner:
387388
config: Generation configuration
388389
token_callback: Optional per-token callback
389390
stats_callback: Optional stats callback
391+
image_token_id: Optional image token ID (or index)
390392
391393
Raises:
392394
RuntimeError: If required keys are missing, shapes are invalid, or generation fails
@@ -424,11 +426,11 @@ class MultimodalRunner:
424426
"""
425427
...
426428

427-
def generate_text(self, inputs: dict, config: GenerationConfig) -> str:
429+
def generate_text_hf(self, inputs: dict, config: GenerationConfig, image_token_id: Optional[int] = None) -> str:
428430
"""
429431
Generate text directly from a HuggingFace processor dict and return as string.
430432
431-
See generate(inputs: dict, ...) for expected keys and constraints.
433+
See generate_hf(inputs: dict, ...) for expected keys and constraints.
432434
"""
433435
...
434436

0 commit comments

Comments (0)