|
24 | 24 | make_text_input, |
25 | 25 | MultimodalInput, |
26 | 26 | MultimodalRunner, |
27 | | - Stats, |
28 | 27 | ) |
29 | 28 |
|
30 | 29 |
|
@@ -114,94 +113,6 @@ def test_repr(self): |
114 | 113 | self.assertIn("warming=False", repr_str) |
115 | 114 |
|
116 | 115 |
|
117 | | -class TestStats(unittest.TestCase): |
118 | | - """Test the Stats class.""" |
119 | | - |
120 | | - def test_attributes(self): |
121 | | - """Test that Stats has all expected attributes.""" |
122 | | - stats = Stats() |
123 | | - |
124 | | - # Check all timing attributes exist |
125 | | - self.assertTrue(hasattr(stats, "SCALING_FACTOR_UNITS_PER_SECOND")) |
126 | | - self.assertTrue(hasattr(stats, "model_load_start_ms")) |
127 | | - self.assertTrue(hasattr(stats, "model_load_end_ms")) |
128 | | - self.assertTrue(hasattr(stats, "inference_start_ms")) |
129 | | - self.assertTrue(hasattr(stats, "token_encode_end_ms")) |
130 | | - self.assertTrue(hasattr(stats, "model_execution_start_ms")) |
131 | | - self.assertTrue(hasattr(stats, "model_execution_end_ms")) |
132 | | - self.assertTrue(hasattr(stats, "prompt_eval_end_ms")) |
133 | | - self.assertTrue(hasattr(stats, "first_token_ms")) |
134 | | - self.assertTrue(hasattr(stats, "inference_end_ms")) |
135 | | - self.assertTrue(hasattr(stats, "aggregate_sampling_time_ms")) |
136 | | - self.assertTrue(hasattr(stats, "num_prompt_tokens")) |
137 | | - self.assertTrue(hasattr(stats, "num_generated_tokens")) |
138 | | - |
139 | | - def test_scaling_factor(self): |
140 | | - """Test the scaling factor constant.""" |
141 | | - stats = Stats() |
142 | | - self.assertEqual(stats.SCALING_FACTOR_UNITS_PER_SECOND, 1000) |
143 | | - |
144 | | - def test_methods(self): |
145 | | - """Test Stats methods.""" |
146 | | - stats = Stats() |
147 | | - |
148 | | - # Test on_sampling_begin and on_sampling_end |
149 | | - stats.on_sampling_begin() |
150 | | - stats.on_sampling_end() |
151 | | - |
152 | | - # Test reset without all_stats |
153 | | - stats.model_load_start_ms = 100 |
154 | | - stats.model_load_end_ms = 200 |
155 | | - stats.inference_start_ms = 300 |
156 | | - stats.num_prompt_tokens = 10 |
157 | | - stats.num_generated_tokens = 20 |
158 | | - |
159 | | - stats.reset(False) |
160 | | - |
161 | | - # Model load times should be preserved |
162 | | - self.assertEqual(stats.model_load_start_ms, 100) |
163 | | - self.assertEqual(stats.model_load_end_ms, 200) |
164 | | - # Other stats should be reset |
165 | | - self.assertEqual(stats.inference_start_ms, 0) |
166 | | - self.assertEqual(stats.num_prompt_tokens, 0) |
167 | | - self.assertEqual(stats.num_generated_tokens, 0) |
168 | | - |
169 | | - # Test reset with all_stats |
170 | | - stats.reset(True) |
171 | | - self.assertEqual(stats.model_load_start_ms, 0) |
172 | | - self.assertEqual(stats.model_load_end_ms, 0) |
173 | | - |
174 | | - def test_to_json_string(self): |
175 | | - """Test JSON string conversion.""" |
176 | | - stats = Stats() |
177 | | - stats.num_prompt_tokens = 10 |
178 | | - stats.num_generated_tokens = 20 |
179 | | - stats.model_load_start_ms = 100 |
180 | | - stats.model_load_end_ms = 200 |
181 | | - stats.inference_start_ms = 300 |
182 | | - stats.inference_end_ms = 1300 |
183 | | - |
184 | | - json_str = stats.to_json_string() |
185 | | - self.assertIn('"prompt_tokens":10', json_str) |
186 | | - self.assertIn('"generated_tokens":20', json_str) |
187 | | - self.assertIn('"model_load_start_ms":100', json_str) |
188 | | - self.assertIn('"model_load_end_ms":200', json_str) |
189 | | - |
190 | | - def test_repr(self): |
191 | | - """Test string representation.""" |
192 | | - stats = Stats() |
193 | | - stats.num_prompt_tokens = 10 |
194 | | - stats.num_generated_tokens = 20 |
195 | | - stats.inference_start_ms = 1000 |
196 | | - stats.inference_end_ms = 2000 |
197 | | - |
198 | | - repr_str = repr(stats) |
199 | | - self.assertIn("Stats", repr_str) |
200 | | - self.assertIn("num_prompt_tokens=10", repr_str) |
201 | | - self.assertIn("num_generated_tokens=20", repr_str) |
202 | | - self.assertIn("tokens_per_second=20", repr_str) # 20 tokens / 1 second |
203 | | - |
204 | | - |
205 | 116 | class TestImage(unittest.TestCase): |
206 | 117 | """Test the Image class.""" |
207 | 118 |
|
@@ -329,7 +240,7 @@ def tearDown(self): |
329 | 240 | def test_initialization_failure(self): |
330 | 241 | """Test that initialization fails gracefully with invalid files.""" |
331 | 242 | with self.assertRaises(RuntimeError) as cm: |
332 | | - runner = MultimodalRunner(self.model_path, self.tokenizer_path) |
| 243 | + MultimodalRunner(self.model_path, self.tokenizer_path, None) |
333 | 244 | # Should fail because the tokenizer file is not valid |
334 | 245 | self.assertIn("Failed to", str(cm.exception)) |
335 | 246 |
|
|
0 commit comments