Skip to content

Commit 5d60871

Browse files
committed
feat: add BatchProcessor utility class
1 parent 05399b6 commit 5d60871

File tree

3 files changed

+160
-0
lines changed

3 files changed

+160
-0
lines changed

src/gradient/_utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
validate_client_instance as validate_client_instance,
3535
ResponseCache as ResponseCache,
3636
RateLimiter as RateLimiter,
37+
BatchProcessor as BatchProcessor,
3738
)
3839
from ._compat import (
3940
get_args as get_args,

src/gradient/_utils/_utils.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,74 @@ def wait_time(self, tokens: int = 1) -> float:
544544
return needed / self.refill_rate
545545

546546

547+
# Batch Processing Classes
548+
class BatchProcessor:
    """Accumulate items and hand them to a callback in batches.

    Items are collected with :meth:`add`. The pending batch is passed to the
    registered callback when it reaches ``batch_size`` items, when
    :meth:`check_timeout` finds that ``timeout_seconds`` have elapsed since
    the most recent ``add``, or when :meth:`force_process` is called.

    The timeout is only evaluated when ``check_timeout`` is called. Nothing is
    processed until a callback has been registered with :meth:`set_callback`;
    until then, added items simply stay pending.
    """

    def __init__(self, batch_size: int = 10, timeout_seconds: float = 5.0) -> None:
        """Initialize batch processor.

        Args:
            batch_size: Number of pending items at which ``add`` triggers
                processing of the batch.
            timeout_seconds: Time in seconds since the last ``add`` after
                which ``check_timeout`` processes the pending batch.
        """
        self.batch_size: int = batch_size
        self.timeout_seconds: float = timeout_seconds
        self._batch: list[Any] = []
        self._last_add_time: float = self._now()
        self._callback: Callable[[list[Any]], Any] | None = None

    def _now(self) -> float:
        """Return a monotonic timestamp in seconds.

        The value is only ever subtracted from another ``_now()`` reading to
        measure elapsed time, so a monotonic clock is used: wall-clock time
        can jump forwards or backwards and would distort the timeout.
        """
        import time

        return time.monotonic()

    def add(self, item: Any) -> None:
        """Append ``item`` to the pending batch and restart the timeout window.

        If the pending batch has reached ``batch_size`` it is processed
        immediately; the callback's return value is discarded here.
        """
        self._batch.append(item)
        self._last_add_time = self._now()

        # Auto-process if batch is full
        if len(self._batch) >= self.batch_size:
            self._process_batch()

    def set_callback(self, callback: Callable[[list[Any]], Any]) -> None:
        """Register the function that receives each batch as a list."""
        self._callback = callback

    def _process_batch(self) -> Any | None:
        """Pass the pending items to the callback and return its result.

        Returns ``None`` without touching the pending items when there is
        nothing pending or no callback has been registered.
        """
        # Compare against None explicitly: a callable object may be falsy.
        if not self._batch or self._callback is None:
            return None

        # Detach the pending items before invoking the callback, so the
        # callback receives its own list and the processor starts empty.
        batch, self._batch = self._batch, []
        return self._callback(batch)

    def force_process(self) -> Any | None:
        """Process the pending batch now, regardless of size or timeout.

        Returns:
            The callback's return value, or ``None`` if nothing was processed.
        """
        return self._process_batch()

    def check_timeout(self) -> Any | None:
        """Process the pending batch if the timeout has elapsed.

        Returns:
            The callback's return value when the batch was processed;
            ``None`` when the batch is empty, the timeout has not elapsed,
            or no callback is registered.
        """
        if not self._batch:
            return None

        elapsed = self._now() - self._last_add_time
        if elapsed >= self.timeout_seconds:
            return self._process_batch()

        return None

    def size(self) -> int:
        """Return the number of items currently pending."""
        return len(self._batch)

    def is_empty(self) -> bool:
        """Return ``True`` when no items are pending."""
        return not self._batch
613+
614+
547615
# API Key Validation Functions
548616
def validate_api_key(api_key: str | None) -> bool:
549617
"""Validate an API key format.

tests/test_batch_processor.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
"""Tests for batch processing functionality."""
2+
3+
import time
4+
import pytest
5+
from gradient._utils import BatchProcessor
6+
7+
8+
class TestBatchProcessor:
    """Exercise BatchProcessor's size-, timeout- and force-triggered processing."""

    def test_batch_processor_basic(self):
        """A batch is processed automatically once it reaches ``batch_size``."""
        seen = []

        def handle(items):
            seen.append(items)
            return f"processed {len(items)} items"

        processor = BatchProcessor(batch_size=3)
        processor.set_callback(handle)

        # Two items stay pending because the batch is not yet full.
        for name in ("item1", "item2"):
            processor.add(name)
        assert processor.size() == 2

        # The third item fills the batch, which triggers the callback.
        processor.add("item3")
        assert processor.size() == 0  # pending items are gone after processing
        assert len(seen) == 1
        assert seen[0] == ["item1", "item2", "item3"]

    def test_batch_processor_timeout(self):
        """``check_timeout`` processes a partial batch once the timeout passes."""
        seen = []
        processor = BatchProcessor(batch_size=10, timeout_seconds=0.1)
        processor.set_callback(seen.append)

        # Add a single item, then wait longer than the configured timeout.
        processor.add("item1")
        time.sleep(0.2)

        # The timeout has elapsed, so this call should process the batch.
        processor.check_timeout()
        assert len(seen) == 1
        assert seen[0] == ["item1"]

    def test_batch_processor_force_process(self):
        """``force_process`` processes a batch that is below ``batch_size``."""
        seen = []
        processor = BatchProcessor(batch_size=10)
        processor.set_callback(seen.append)

        # Stay below the batch size so nothing is processed automatically.
        for name in ("item1", "item2"):
            processor.add(name)
        assert processor.size() == 2

        # Process explicitly.
        processor.force_process()
        assert processor.size() == 0
        assert len(seen) == 1
        assert seen[0] == ["item1", "item2"]

    def test_batch_processor_multiple_batches(self):
        """Consecutive full batches are delivered separately and in order."""
        seen = []
        processor = BatchProcessor(batch_size=2)
        processor.set_callback(seen.append)

        # Every second add fills the batch and triggers processing.
        for name in ("item1", "item2", "item3", "item4"):
            processor.add(name)

        assert len(seen) == 2
        assert seen[0] == ["item1", "item2"]
        assert seen[1] == ["item3", "item4"]

0 commit comments

Comments
 (0)