Skip to content

Commit 95202d2

Browse files
committed
Fix tests
1 parent 87923e6 commit 95202d2

File tree

2 files changed

+178
-241
lines changed

2 files changed

+178
-241
lines changed

extension/llm/runner/_llm_runner.pyi

Lines changed: 75 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -4,291 +4,283 @@ Type stubs for _llm_runner module.
44
This file provides type annotations for the ExecuTorch LLM Runner Python bindings.
55
"""
66

7-
from typing import Callable, List, Optional, Union, overload

import numpy as np
from numpy.typing import NDArray
1011

1112
class GenerationConfig:
    """Configuration for text generation."""

    echo: bool
    """Whether to echo the input prompt in the output."""

    max_new_tokens: int
    """Maximum number of new tokens to generate (-1 for auto)."""

    warming: bool
    """Whether this is a warmup run (affects perf benchmarking)."""

    seq_len: int
    """Maximum number of total tokens (-1 for auto)."""

    temperature: float
    """Temperature for sampling (higher = more random)."""

    num_bos: int
    """Number of BOS tokens to add to the prompt."""

    num_eos: int
    """Number of EOS tokens to add to the prompt."""

    def __init__(self) -> None:
        """Initialize GenerationConfig with default values."""
        ...

    def resolve_max_new_tokens(
        self, max_context_len: int, num_prompt_tokens: int
    ) -> int:
        """
        Resolve the maximum number of new tokens to generate based on constraints.

        Args:
            max_context_len: The maximum context length supported by the model
            num_prompt_tokens: The number of tokens in the input prompt

        Returns:
            The resolved maximum number of new tokens to generate
        """
        ...

    def __repr__(self) -> str: ...
5456

5557
class Stats:
    """Statistics for LLM generation performance."""

    SCALING_FACTOR_UNITS_PER_SECOND: int
    """Scaling factor for timestamps (1000 for milliseconds)."""

    model_load_start_ms: int
    """Start time of model loading in milliseconds."""

    model_load_end_ms: int
    """End time of model loading in milliseconds."""

    inference_start_ms: int
    """Start time of inference in milliseconds."""

    token_encode_end_ms: int
    """End time of tokenizer encoding in milliseconds."""

    model_execution_start_ms: int
    """Start time of model execution in milliseconds."""

    model_execution_end_ms: int
    """End time of model execution in milliseconds."""

    prompt_eval_end_ms: int
    """End time of prompt evaluation in milliseconds."""

    first_token_ms: int
    """Timestamp when the first generated token is emitted."""

    inference_end_ms: int
    """End time of inference/generation in milliseconds."""

    aggregate_sampling_time_ms: int
    """Total time spent in sampling across all tokens."""

    num_prompt_tokens: int
    """Number of tokens in the input prompt."""

    num_generated_tokens: int
    """Number of tokens generated."""

    def on_sampling_begin(self) -> None:
        """Mark the beginning of a sampling operation."""
        ...

    def on_sampling_end(self) -> None:
        """Mark the end of a sampling operation."""
        ...

    def reset(self, all_stats: bool = False) -> None:
        """
        Reset statistics.

        Args:
            all_stats: If True, reset all stats including model load times.
                       If False, preserve model load times.
        """
        ...

    def to_json_string(self) -> str:
        """Convert stats to JSON string representation."""
        ...

    def __repr__(self) -> str: ...
121122

122123
class Image:
    """Container for image data."""

    data: List[int]
    """Raw image data as a list of uint8 values."""

    width: int
    """Image width in pixels."""

    height: int
    """Image height in pixels."""

    channels: int
    """Number of color channels (3 for RGB, 4 for RGBA)."""

    def __init__(self) -> None:
        """Initialize an empty Image."""
        ...

    def __repr__(self) -> str: ...
143143

144144
class MultimodalInput:
    """Container for multimodal input data (text, image, etc.)."""

    # The two constructors are alternative signatures of the same method, so
    # each must carry @overload; without it the second `def __init__` simply
    # redefines the first and type checkers lose the text-based constructor.
    @overload
    def __init__(self, text: str) -> None:
        """
        Create a MultimodalInput with text.

        Args:
            text: The input text string
        """
        ...

    @overload
    def __init__(self, image: Image) -> None:
        """
        Create a MultimodalInput with an image.

        Args:
            image: The input image
        """
        ...

    def is_text(self) -> bool:
        """Check if this input contains text."""
        ...

    def is_image(self) -> bool:
        """Check if this input contains an image."""
        ...

    def get_text(self) -> Optional[str]:
        """
        Get the text content if this is a text input.

        Returns:
            The text string if this is a text input, None otherwise
        """
        ...

    def __repr__(self) -> str: ...
184183

185184
class MultimodalRunner:
    """Runner for multimodal language models."""

    def __init__(
        self, model_path: str, tokenizer_path: str, data_path: Optional[str] = None
    ) -> None:
        """
        Initialize a MultimodalRunner.

        Args:
            model_path: Path to the model file (.pte)
            tokenizer_path: Path to the tokenizer file
            data_path: Optional path to additional data file

        Raises:
            RuntimeError: If initialization fails
        """
        ...

    def generate(
        self,
        inputs: List[MultimodalInput],
        config: GenerationConfig,
        token_callback: Optional[Callable[[str], None]] = None,
        stats_callback: Optional[Callable[[Stats], None]] = None,
    ) -> None:
        """
        Generate text from multimodal inputs.

        Args:
            inputs: List of multimodal inputs (text, images, etc.)
            config: Generation configuration
            token_callback: Optional callback called for each generated token
            stats_callback: Optional callback called with generation statistics

        Raises:
            RuntimeError: If generation fails
        """
        ...

    def generate_text(
        self, inputs: List[MultimodalInput], config: GenerationConfig
    ) -> str:
        """
        Generate text and return the complete result as a string.

        Args:
            inputs: List of multimodal inputs (text, images, etc.)
            config: Generation configuration

        Returns:
            The generated text as a string

        Raises:
            RuntimeError: If generation fails
        """
        ...

    def stop(self) -> None:
        """Stop the current generation process."""
        ...

    def reset(self) -> None:
        """Reset the runner state and KV cache."""
        ...

    def get_vocab_size(self) -> int:
        """
        Get the vocabulary size of the model.

        Returns:
            The vocabulary size, or -1 if not available
        """
        ...

    def __repr__(self) -> str: ...
267260

268261
def make_text_input(text: str) -> MultimodalInput:
    """
    Create a text input for multimodal processing.

    Args:
        text: The input text string

    Returns:
        A MultimodalInput containing the text
    """
    ...
279272

280-
281273
def make_image_input(image_array: NDArray[np.uint8]) -> MultimodalInput:
    """
    Create an image input from a numpy array.

    Args:
        image_array: Numpy array with shape (H, W, C) where C is 3 (RGB) or 4 (RGBA)

    Returns:
        A MultimodalInput containing the image

    Raises:
        RuntimeError: If the array has invalid dimensions or number of channels
    """
    ...

0 commit comments

Comments
 (0)