Skip to content

Commit ac34be9

Browse files
yossiovadiaAias00
authored andcommitted
Fix/improve batch classification test (vllm-project#319)
* feat: improve batch classification test to validate accuracy Previously, the batch classification test only validated HTTP status and result count, but never checked if the classifications were correct. The expected_categories variable was created but never used for validation. Changes: - Extract actual categories from batch classification results - Compare against expected categories and calculate accuracy percentage - Add detailed output showing each classification result - Assert that accuracy meets 75% threshold - Maintain backward compatibility with existing HTTP/count checks This improved test now properly catches classification accuracy issues and will fail when the classification system returns incorrect results, exposing problems that were previously hidden. Related to issue vllm-project#318: Batch Classification API Returns Incorrect Categories Signed-off-by: Yossi Ovadia <[email protected]> * style: apply black formatting to classification test Automatic formatting applied by black pre-commit hook. Signed-off-by: Yossi Ovadia <[email protected]> --------- Signed-off-by: Yossi Ovadia <[email protected]> Signed-off-by: liuhy <[email protected]>
1 parent 55896dd commit ac34be9

File tree

1 file changed

+56
-5
lines changed

1 file changed

+56
-5
lines changed

e2e-tests/03-classification-api-test.py

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -189,29 +189,80 @@ def test_batch_classification(self):
189189
response_json = response.json()
190190
results = response_json.get("results", [])
191191

192+
# Extract actual categories from results
193+
actual_categories = []
194+
correct_classifications = 0
195+
196+
for i, result in enumerate(results):
197+
if isinstance(result, dict):
198+
actual_category = result.get("category", "unknown")
199+
else:
200+
actual_category = "unknown"
201+
202+
actual_categories.append(actual_category)
203+
204+
if (
205+
i < len(expected_categories)
206+
and actual_category == expected_categories[i]
207+
):
208+
correct_classifications += 1
209+
210+
# Calculate accuracy
211+
accuracy = (
212+
(correct_classifications / len(expected_categories)) * 100
213+
if expected_categories
214+
else 0
215+
)
216+
192217
self.print_response_info(
193218
response,
194219
{
195220
"Total Texts": len(texts),
196221
"Results Count": len(results),
197222
"Processing Time (ms)": response_json.get("processing_time_ms", 0),
223+
"Accuracy": f"{accuracy:.1f}% ({correct_classifications}/{len(expected_categories)})",
198224
},
199225
)
200226

201-
passed = response.status_code == 200 and len(results) == len(texts)
227+
# Print detailed classification results
228+
print("\n📊 Detailed Classification Results:")
229+
for i, (text, expected, actual) in enumerate(
230+
zip(texts, expected_categories, actual_categories)
231+
):
232+
status = "✅" if expected == actual else "❌"
233+
print(f" {i+1}. {status} Expected: {expected:<15} | Actual: {actual:<15}")
234+
print(f" Text: {text[:60]}...")
235+
236+
# Check basic requirements first
237+
basic_checks_passed = response.status_code == 200 and len(results) == len(texts)
238+
239+
# Check classification accuracy (should be high for a working system)
240+
accuracy_threshold = 75.0 # Expect at least 75% accuracy
241+
accuracy_passed = accuracy >= accuracy_threshold
242+
243+
overall_passed = basic_checks_passed and accuracy_passed
202244

203245
self.print_test_result(
204-
passed=passed,
246+
passed=overall_passed,
205247
message=(
206-
f"Successfully classified {len(results)} texts"
207-
if passed
208-
else f"Batch classification failed or returned wrong count"
248+
f"Successfully classified {len(results)} texts with {accuracy:.1f}% accuracy"
249+
if overall_passed
250+
else f"Batch classification issues: Basic checks: {basic_checks_passed}, Accuracy: {accuracy:.1f}% (threshold: {accuracy_threshold}%)"
209251
),
210252
)
211253

254+
# Basic checks
212255
self.assertEqual(response.status_code, 200, "Batch request failed")
213256
self.assertEqual(len(results), len(texts), "Result count mismatch")
214257

258+
# NEW: Validate classification accuracy
259+
self.assertGreaterEqual(
260+
accuracy,
261+
accuracy_threshold,
262+
f"Classification accuracy too low: {accuracy:.1f}% < {accuracy_threshold}%. "
263+
f"Expected: {expected_categories}, Actual: {actual_categories}",
264+
)
265+
215266

216267
if __name__ == "__main__":
217268
unittest.main()

0 commit comments

Comments
 (0)