+import asyncio
+import time
 from collections.abc import Iterable
 from dataclasses import dataclass
 from typing import Any, cast
@@ -31,15 +33,23 @@ def __init__(self, model_name: str = "default") -> None:
 
 
 class MockEvaluationPipeline(EvaluationPipeline[MockEvaluationTarget, MockEvaluationData, MockEvaluationResult]):
+    def __init__(self, evaluation_target: MockEvaluationTarget, slow: bool = False):
+        super().__init__(evaluation_target)
+        self._slow = slow
+
     async def __call__(self, data: Iterable[MockEvaluationData]) -> Iterable[MockEvaluationResult]:
-        return [
-            MockEvaluationResult(
-                input_data=row.input_data,
-                processed_output=f"{self.evaluation_target.model_name}_{row.input_data}",
-                is_correct=row.input_data % 2 == 0,
+        results = []
+        for row in data:
+            if self._slow:
+                await asyncio.sleep(0.5)
+            results.append(
+                MockEvaluationResult(
+                    input_data=row.input_data,
+                    processed_output=f"{self.evaluation_target.model_name}_{row.input_data}",
+                    is_correct=row.input_data % 2 == 0,
+                )
             )
-            for row in data
-        ]
+        return results
 
     @classmethod
     def from_config(cls, config: dict) -> "MockEvaluationPipeline":
@@ -102,6 +112,66 @@ async def test_run_evaluation(
     assert all("test_model_" in r.processed_output for r in results.results)
 
 
+@pytest.mark.parametrize(
+    ("parallelize_batches", "expected_results", "expected_accuracy"),
+    [(False, 4, 0.5), (True, 4, 0.5)],
+)
+async def test_run_evaluation_with_parallel_batches(
+    parallelize_batches: bool,
+    expected_results: int,
+    expected_accuracy: float,
+) -> None:
+    target = MockEvaluationTarget(model_name="parallel_test_model")
+    pipeline = MockEvaluationPipeline(target)
+    dataloader = MockDataLoader()
+    metrics = MetricSet(*[MockMetric()])
+    evaluator = Evaluator(batch_size=2, parallelize_batches=parallelize_batches)
+
+    results = await evaluator.compute(
+        pipeline=pipeline,
+        dataloader=dataloader,
+        metricset=metrics,
+    )
+
+    assert len(results.results) == expected_results
+    assert len(results.errors) == 0
+    assert results.metrics["accuracy"] == expected_accuracy
+    assert all("parallel_test_model_" in r.processed_output for r in results.results)
+
+
+async def test_parallel_batches_performance() -> None:
+    """Test that parallel processing is faster than sequential processing."""
+    target = MockEvaluationTarget(model_name="timing_test_model")
+    pipeline = MockEvaluationPipeline(target, slow=True)
+    dataloader = MockDataLoader(dataset_size=4)
+    metrics = MetricSet(*[MockMetric()])
+
+    # Test sequential processing
+    evaluator_sequential = Evaluator(batch_size=2, parallelize_batches=False)
+    start_time = time.perf_counter()
+    results_sequential = await evaluator_sequential.compute(
+        pipeline=pipeline,
+        dataloader=dataloader,
+        metricset=metrics,
+    )
+    sequential_time = time.perf_counter() - start_time
+
+    evaluator_parallel = Evaluator(batch_size=2, parallelize_batches=True)
+    start_time = time.perf_counter()
+    results_parallel = await evaluator_parallel.compute(
+        pipeline=pipeline,
+        dataloader=dataloader,
+        metricset=metrics,
+    )
+    parallel_time = time.perf_counter() - start_time
+
+    assert len(results_sequential.results) == len(results_parallel.results)
+    assert results_sequential.metrics == results_parallel.metrics
+
+    # Parallel processing should be roughly 2x faster, but we add some margin
+    assert parallel_time < sequential_time * 0.7
+
+
 async def test_run_from_config() -> None:
     config = {
         "evaluation": {
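
The new tests exercise an Evaluator(batch_size=..., parallelize_batches=...) option whose implementation is not part of this diff. As a rough sketch of the behaviour the timing test expects, batch-level parallelism can be expressed with asyncio.gather; the helper below is illustrative only, and its name and signature are assumptions rather than the library's actual API.

import asyncio
from collections.abc import Awaitable, Callable, Iterable, Sequence
from typing import Any


async def run_batches(
    pipeline: Callable[[Sequence[Any]], Awaitable[Iterable[Any]]],
    batches: Sequence[Sequence[Any]],
    parallelize_batches: bool,
) -> list[Any]:
    """Hypothetical helper: run the pipeline over each batch, optionally overlapping batches."""
    if parallelize_batches:
        # All batch coroutines are awaited concurrently, so wall time is
        # roughly that of the slowest single batch.
        batch_results = await asyncio.gather(*(pipeline(batch) for batch in batches))
    else:
        # Batches are awaited one after another, so wall time is the sum
        # of the individual batch times.
        batch_results = [await pipeline(batch) for batch in batches]
    # Flatten the per-batch result iterables into one flat list of results.
    return [result for batch in batch_results for result in batch]

Under that assumption the timing test's numbers line up: with a 0.5 s sleep per row, dataset_size=4 and batch_size=2, the sequential run takes roughly 4 × 0.5 = 2 s, while two concurrently processed batches take roughly 2 × 0.5 = 1 s, which is the gap the parallel_time < sequential_time * 0.7 assertion allows margin for.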