
Commit bb1fcc8

Authored and committed by vladimir-kivi-ds, GlockPL, Konrad Czarnota, akotyla, and jakubduda-dsai
feat: Optional parallel batches execution in ragbits.evaluate.Evaluator (#769)
Co-authored-by: GlockPL <[email protected]>
Co-authored-by: Konrad Czarnota <[email protected]>
Co-authored-by: GlockPL <[email protected]>
Co-authored-by: akotyla <[email protected]>
Co-authored-by: jakubduda-dsai <[email protected]>
Co-authored-by: ds-sebastianchwilczynski <[email protected]>
Co-authored-by: dazy-ds <[email protected]>
1 parent e15893e · commit bb1fcc8
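As a rough usage sketch of the new option (the pipeline, dataloader, and metric-set objects are placeholders; the `parallelize_batches` argument and the `compute` keyword names follow the code and tests in this commit):

```python
from ragbits.evaluate.evaluator import Evaluator


async def evaluate_dataset(pipeline, dataloader, metricset):
    # Placeholder arguments: any compatible EvaluationPipeline, DataLoader and MetricSet.
    # parallelize_batches=True runs the samples of each batch concurrently via
    # asyncio.gather; the default (False) keeps the previous sequential behaviour.
    evaluator = Evaluator(batch_size=8, parallelize_batches=True)
    return await evaluator.compute(
        pipeline=pipeline,
        dataloader=dataloader,
        metricset=metricset,
    )
```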

3 files changed, +106 -13 lines changed

packages/ragbits-evaluate/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -2,6 +2,8 @@
 
 ## Unreleased
 
+- Optional parallel batches execution in ragbits.evaluate.Evaluator (#769)
+
 ## 1.2.2 (2025-08-08)
 
 ### Changed
@@ -142,6 +144,7 @@
 - ragbits-core updated to version v0.10.1
 
 ## 0.10.0 (2025-03-17)
+
 ### Changed
 
 - ragbits-core updated to version v0.10.0

packages/ragbits-evaluate/src/ragbits/evaluate/evaluator.py

Lines changed: 26 additions & 6 deletions
@@ -1,7 +1,7 @@
 import asyncio
 import random
 import time
-from collections.abc import Awaitable, Callable, Iterable
+from collections.abc import Awaitable, Callable, Iterable, Sized
 from dataclasses import dataclass
 from typing import Generic, ParamSpec, TypeVar
 
@@ -71,6 +71,7 @@ def __init__(
         num_retries: int = 3,
         backoff_multiplier: int = 1,
         backoff_max: int = 60,
+        parallelize_batches: bool = False,
     ) -> None:
         """
         Initialize the Evaluator instance.
@@ -80,11 +81,13 @@ def __init__(
             num_retries: The number of retries per evaluation pipeline inference error.
             backoff_multiplier: The base delay multiplier for exponential backoff (in seconds).
             backoff_max: The maximum allowed delay (in seconds) between retries.
+            parallelize_batches: Whether to process samples within each batch in parallel (asyncio.gather).
         """
         self.batch_size = batch_size
         self.num_retries = num_retries
         self.backoff_multiplier = backoff_multiplier
         self.backoff_max = backoff_max
+        self.parallelize_batches = parallelize_batches
 
     @classmethod
     async def run_from_config(cls, config: dict) -> EvaluatorResult:
@@ -156,16 +159,33 @@ async def _call_pipeline(
             The evaluation results and performance metrics.
         """
         start_time = time.perf_counter()
-        outputs = [
-            await self._call_with_error_handling(pipeline, data)
-            for data in tqdm(batched(dataset, self.batch_size), desc="Evaluation")
-        ]
+
+        total_samples = len(dataset) if isinstance(dataset, Sized) else None
+        batches = batched(dataset, self.batch_size)
+        outputs: list[Iterable[EvaluationResultT] | Exception] = []
+
+        with tqdm(total=total_samples, desc="Evaluation", unit="sample") as progress_bar:
+            for batch in batches:
+                batch_list = list(batch)
+
+                if self.parallelize_batches:
+                    tasks = [self._call_with_error_handling(pipeline, [sample]) for sample in batch_list]
+                    batch_results = await asyncio.gather(*tasks)
+
+                    for result in batch_results:
+                        outputs.append(result)
+                        progress_bar.update(1)
+                else:
+                    result = await self._call_with_error_handling(pipeline, batch_list)
+                    outputs.append(result)
+                    progress_bar.update(len(batch_list))
+
         end_time = time.perf_counter()
 
         errors = [output for output in outputs if isinstance(output, Exception)]
         results = [item for output in outputs if not isinstance(output, Exception) for item in output]
 
-        return results, errors, self._compute_time_perf(start_time, end_time, len(outputs))
+        return results, errors, self._compute_time_perf(start_time, end_time, len(results))
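To see the control flow of `_call_pipeline` in isolation, here is a self-contained sketch of the same pattern outside ragbits (the `evaluate` coroutine and the numbers are invented; `itertools.batched` needs Python 3.12+): with parallelization on, each sample in a batch becomes its own task awaited together with `asyncio.gather`; with it off, the whole batch is handled in one sequential call.

```python
import asyncio
from itertools import batched  # Python 3.12+


async def evaluate(sample: int) -> int:
    """Stand-in for a single-sample pipeline call."""
    await asyncio.sleep(0.1)
    return sample * 2


async def run(dataset: list[int], batch_size: int, parallelize_batches: bool) -> list[int]:
    outputs: list[int] = []
    for batch in batched(dataset, batch_size):
        if parallelize_batches:
            # One task per sample, awaited together: the batch finishes in roughly
            # the time of its slowest sample.
            outputs.extend(await asyncio.gather(*(evaluate(s) for s in batch)))
        else:
            # Samples processed one after another within the batch.
            for sample in batch:
                outputs.append(await evaluate(sample))
    return outputs


print(asyncio.run(run(list(range(8)), batch_size=4, parallelize_batches=True)))
```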

packages/ragbits-evaluate/tests/unit/test_evaluator.py

Lines changed: 77 additions & 7 deletions
@@ -1,3 +1,5 @@
+import asyncio
+import time
 from collections.abc import Iterable
 from dataclasses import dataclass
 from typing import Any, cast
@@ -31,15 +33,23 @@ def __init__(self, model_name: str = "default") -> None:
 
 
 class MockEvaluationPipeline(EvaluationPipeline[MockEvaluationTarget, MockEvaluationData, MockEvaluationResult]):
+    def __init__(self, evaluation_target: MockEvaluationTarget, slow: bool = False):
+        super().__init__(evaluation_target)
+        self._slow = slow
+
     async def __call__(self, data: Iterable[MockEvaluationData]) -> Iterable[MockEvaluationResult]:
-        return [
-            MockEvaluationResult(
-                input_data=row.input_data,
-                processed_output=f"{self.evaluation_target.model_name}_{row.input_data}",
-                is_correct=row.input_data % 2 == 0,
+        results = []
+        for row in data:
+            if self._slow:
+                await asyncio.sleep(0.5)
+            results.append(
+                MockEvaluationResult(
+                    input_data=row.input_data,
+                    processed_output=f"{self.evaluation_target.model_name}_{row.input_data}",
+                    is_correct=row.input_data % 2 == 0,
+                )
             )
-            for row in data
-        ]
+        return results
 
     @classmethod
     def from_config(cls, config: dict) -> "MockEvaluationPipeline":
@@ -102,6 +112,66 @@ async def test_run_evaluation(
     assert all("test_model_" in r.processed_output for r in results.results)
 
 
+@pytest.mark.parametrize(
+    ("parallelize_batches", "expected_results", "expected_accuracy"),
+    [(False, 4, 0.5), (True, 4, 0.5)],
+)
+async def test_run_evaluation_with_parallel_batches(
+    parallelize_batches: bool,
+    expected_results: int,
+    expected_accuracy: float,
+) -> None:
+    target = MockEvaluationTarget(model_name="parallel_test_model")
+    pipeline = MockEvaluationPipeline(target)
+    dataloader = MockDataLoader()
+    metrics = MetricSet(*[MockMetric()])
+    evaluator = Evaluator(batch_size=2, parallelize_batches=parallelize_batches)
+
+    results = await evaluator.compute(
+        pipeline=pipeline,
+        dataloader=dataloader,
+        metricset=metrics,
+    )
+
+    assert len(results.results) == expected_results
+    assert len(results.errors) == 0
+    assert results.metrics["accuracy"] == expected_accuracy
+    assert all("parallel_test_model_" in r.processed_output for r in results.results)
+
+
+async def test_parallel_batches_performance() -> None:
+    """Test that parallel processing is faster than sequential processing."""
+    target = MockEvaluationTarget(model_name="timing_test_model")
+    pipeline = MockEvaluationPipeline(target, slow=True)
+    dataloader = MockDataLoader(dataset_size=4)
+    metrics = MetricSet(*[MockMetric()])
+
+    # Test sequential processing
+    evaluator_sequential = Evaluator(batch_size=2, parallelize_batches=False)
+    start_time = time.perf_counter()
+    results_sequential = await evaluator_sequential.compute(
+        pipeline=pipeline,
+        dataloader=dataloader,
+        metricset=metrics,
+    )
+    sequential_time = time.perf_counter() - start_time
+
+    evaluator_parallel = Evaluator(batch_size=2, parallelize_batches=True)
+    start_time = time.perf_counter()
+    results_parallel = await evaluator_parallel.compute(
+        pipeline=pipeline,
+        dataloader=dataloader,
+        metricset=metrics,
+    )
+    parallel_time = time.perf_counter() - start_time
+
+    assert len(results_sequential.results) == len(results_parallel.results)
+    assert results_sequential.metrics == results_parallel.metrics
+
+    # Parallel processing should be roughly 2x faster, but we add some margin
+    assert parallel_time < sequential_time * 0.7
+
+
 async def test_run_from_config() -> None:
     config = {
         "evaluation": {
