Skip to content

Commit 14d2db7

Browse files
committed
Fix
1 parent 1ac3ac1 commit 14d2db7

File tree

3 files changed

+6
-93
lines changed

3 files changed

+6
-93
lines changed

extension/llm/runner/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,7 @@
2525

2626
try:
2727
# Import shared components from the compiled C++ extension
28-
from ._llm_runner import (
28+
from executorch.extension.llm.runner._llm_runner import ( # noqa: F401
2929
GenerationConfig,
3030
Image,
3131
make_image_input,
@@ -105,7 +105,9 @@ def create_text_input(self, text: str):
105105
"""
106106
return make_text_input(text)
107107

108-
def create_image_input(self, image: Union[str, Path, np.ndarray, "PILImage.Image"]):
108+
def create_image_input( # noqa: C901
109+
self, image: Union[str, Path, np.ndarray, "PILImage.Image"]
110+
):
109111
"""
110112
Create an image input for multimodal processing.
111113

extension/llm/runner/test/test_pybindings.py

Lines changed: 1 addition & 90 deletions
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,6 @@
2424
make_text_input,
2525
MultimodalInput,
2626
MultimodalRunner,
27-
Stats,
2827
)
2928

3029

@@ -114,94 +113,6 @@ def test_repr(self):
114113
self.assertIn("warming=False", repr_str)
115114

116115

117-
class TestStats(unittest.TestCase):
118-
"""Test the Stats class."""
119-
120-
def test_attributes(self):
121-
"""Test that Stats has all expected attributes."""
122-
stats = Stats()
123-
124-
# Check all timing attributes exist
125-
self.assertTrue(hasattr(stats, "SCALING_FACTOR_UNITS_PER_SECOND"))
126-
self.assertTrue(hasattr(stats, "model_load_start_ms"))
127-
self.assertTrue(hasattr(stats, "model_load_end_ms"))
128-
self.assertTrue(hasattr(stats, "inference_start_ms"))
129-
self.assertTrue(hasattr(stats, "token_encode_end_ms"))
130-
self.assertTrue(hasattr(stats, "model_execution_start_ms"))
131-
self.assertTrue(hasattr(stats, "model_execution_end_ms"))
132-
self.assertTrue(hasattr(stats, "prompt_eval_end_ms"))
133-
self.assertTrue(hasattr(stats, "first_token_ms"))
134-
self.assertTrue(hasattr(stats, "inference_end_ms"))
135-
self.assertTrue(hasattr(stats, "aggregate_sampling_time_ms"))
136-
self.assertTrue(hasattr(stats, "num_prompt_tokens"))
137-
self.assertTrue(hasattr(stats, "num_generated_tokens"))
138-
139-
def test_scaling_factor(self):
140-
"""Test the scaling factor constant."""
141-
stats = Stats()
142-
self.assertEqual(stats.SCALING_FACTOR_UNITS_PER_SECOND, 1000)
143-
144-
def test_methods(self):
145-
"""Test Stats methods."""
146-
stats = Stats()
147-
148-
# Test on_sampling_begin and on_sampling_end
149-
stats.on_sampling_begin()
150-
stats.on_sampling_end()
151-
152-
# Test reset without all_stats
153-
stats.model_load_start_ms = 100
154-
stats.model_load_end_ms = 200
155-
stats.inference_start_ms = 300
156-
stats.num_prompt_tokens = 10
157-
stats.num_generated_tokens = 20
158-
159-
stats.reset(False)
160-
161-
# Model load times should be preserved
162-
self.assertEqual(stats.model_load_start_ms, 100)
163-
self.assertEqual(stats.model_load_end_ms, 200)
164-
# Other stats should be reset
165-
self.assertEqual(stats.inference_start_ms, 0)
166-
self.assertEqual(stats.num_prompt_tokens, 0)
167-
self.assertEqual(stats.num_generated_tokens, 0)
168-
169-
# Test reset with all_stats
170-
stats.reset(True)
171-
self.assertEqual(stats.model_load_start_ms, 0)
172-
self.assertEqual(stats.model_load_end_ms, 0)
173-
174-
def test_to_json_string(self):
175-
"""Test JSON string conversion."""
176-
stats = Stats()
177-
stats.num_prompt_tokens = 10
178-
stats.num_generated_tokens = 20
179-
stats.model_load_start_ms = 100
180-
stats.model_load_end_ms = 200
181-
stats.inference_start_ms = 300
182-
stats.inference_end_ms = 1300
183-
184-
json_str = stats.to_json_string()
185-
self.assertIn('"prompt_tokens":10', json_str)
186-
self.assertIn('"generated_tokens":20', json_str)
187-
self.assertIn('"model_load_start_ms":100', json_str)
188-
self.assertIn('"model_load_end_ms":200', json_str)
189-
190-
def test_repr(self):
191-
"""Test string representation."""
192-
stats = Stats()
193-
stats.num_prompt_tokens = 10
194-
stats.num_generated_tokens = 20
195-
stats.inference_start_ms = 1000
196-
stats.inference_end_ms = 2000
197-
198-
repr_str = repr(stats)
199-
self.assertIn("Stats", repr_str)
200-
self.assertIn("num_prompt_tokens=10", repr_str)
201-
self.assertIn("num_generated_tokens=20", repr_str)
202-
self.assertIn("tokens_per_second=20", repr_str) # 20 tokens / 1 second
203-
204-
205116
class TestImage(unittest.TestCase):
206117
"""Test the Image class."""
207118

@@ -329,7 +240,7 @@ def tearDown(self):
329240
def test_initialization_failure(self):
330241
"""Test that initialization fails gracefully with invalid files."""
331242
with self.assertRaises(RuntimeError) as cm:
332-
runner = MultimodalRunner(self.model_path, self.tokenizer_path)
243+
MultimodalRunner(self.model_path, self.tokenizer_path, None)
333244
# Should fail because the tokenizer file is not valid
334245
self.assertIn("Failed to", str(cm.exception))
335246

extension/llm/runner/utils.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@
2323
except ImportError:
2424
HAS_PIL = False
2525

26-
from ._llm_runner import GenerationConfig
26+
from executorch.extension.llm.runner._llm_runner import GenerationConfig # noqa: F401
2727

2828

2929
def load_image_from_file(

0 commit comments

Comments
 (0)