AsyncBuffer update #130
base: main
Changes from 2 commits
@@ -146,6 +146,205 @@ async def _maybe_process_batch(self):
    @abstractmethod
    async def _batched_call(self, batch_kwargs: dict[str, Any]):
        """Logic to call the worker on a batch of inputs."""


class AsyncBufferedWorker(ABC):
    """Abstract class for a worker that buffers inputs and processes them in batches."""

    def __init__(
        self,
        batch_size: int,
        max_wait_interval: float,
        collate_fn: Callable = lambda x: x,
        decollate_fn: Callable = lambda x: x,
    ):
        """Initialize.

        Args:
            batch_size: The target batch size to use when calling the module. As soon as
                batch_size calls are made, a forward pass is executed.
            max_wait_interval: The maximum time to wait for a batch to fill up before
                executing the calls we have buffered.
            collate_fn: A function to pre-process a list of inputs into a batch. Defaults to a
                no-op.
            decollate_fn: Roughly the inverse of collate_fn. This function should take the
                batched output and return an ordered list of outputs. Defaults to a no-op.
        """
        self.batch_size = batch_size
        self.timeout = max_wait_interval
        self.collate_fn = collate_fn
        self.decollate_fn = decollate_fn

        self._work_buffer: list[tuple[float, UUID, dict[str, Any]]] = []
        self._result_buffer: dict[UUID, Any] = {}
        self._lock = asyncio.Lock()
        self._batch_ready_event = asyncio.Event()
        self._processed_events = {}
        self._counter = 0
        self._events_count = {}

    async def __call__(self, **kwargs):
        request_id = uuid4()
        request_ts = time.time()

        async with self._lock:
            self._processed_events[request_id] = asyncio.Event()
            self._events_count[request_id] = self._counter
            self._counter += 1
            print(f"Started Request ID: {request_id}, Counter: {self._events_count[request_id]}")
Collaborator
Let's use
            self._work_buffer.append((request_ts, request_id, kwargs))

            # If we've reached batch size, we trigger the processing event immediately
            if len(self._work_buffer) >= self.batch_size:
                self._batch_ready_event.set()

        try:
            # Wait for either the batch to fill up or the timeout to expire
            await asyncio.wait_for(self._batch_ready_event.wait(), timeout=self.timeout)
        except asyncio.TimeoutError:
            pass
Comment on lines +206 to +210
Collaborator
Alternate way of doing this using built-in
I am not sure if you need to
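
The built-in the reviewer refers to is cut off in the comment above. One possible reading, shown here purely as an assumption, is the asyncio.timeout() context manager (Python 3.11+), which could replace the wait_for pattern on lines +206 to +210:

        # Hypothetical alternative for the timeout wait, assuming asyncio.timeout()
        # is the built-in meant by the truncated comment above.
        try:
            async with asyncio.timeout(self.timeout):
                # Wait for either the batch to fill up or the timeout to expire
                await self._batch_ready_event.wait()
        except TimeoutError:
            # Timed out before the batch filled; process whatever is buffered.
            pass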

        await self._maybe_process_batch()

        await self._processed_events[request_id].wait()

        async with self._lock:
            print(f"Finished Request ID: {request_id}, Counter: {self._events_count[request_id]}")
Collaborator
Let's use
            self._events_count.pop(request_id)
            self._processed_events.pop(request_id)
            return self._result_buffer.pop(request_id)

    async def _maybe_process_batch(self):
        """If the buffer is >= batch size or we have been waiting long enough, process the old batch.

        If neither condition is met, do nothing.
        """
        async with self._lock:
            # If there's at least one request in the buffer, we can process it
            if len(self._work_buffer) == 0:
                return

            self._work_buffer.sort(key=operator.itemgetter(0))

            batch = self._work_buffer[: self.batch_size]
            self._work_buffer = self._work_buffer[self.batch_size :]

            if len(self._work_buffer) < self.batch_size:
                self._batch_ready_event.clear()

            # Construct the batch tensors
            sample_kwargs = [x[2] for x in batch]
            batch_kwargs = self.collate_fn(sample_kwargs)

        print(f"starting to wait for batched call, counter: {self._counter}")
        batched_results = await self._batched_call(batch_kwargs)
        print(f"finished waiting for batched call, counter: {self._counter}")
        request_ids = [x[1] for x in batch]
        results = self.decollate_fn(batched_results)
        async with self._lock:
            print(f"updating result buffer, counter: {self._counter}")
            self._result_buffer.update(zip(request_ids, results, strict=True))
            for request_id in request_ids:
                self._processed_events[request_id].set()

    def _process_batch(self):
        """Processes the current batch."""

    @abstractmethod
    async def _batched_call(self, batch_kwargs: dict[str, Any]):
        """Logic to call the worker on a batch of inputs."""
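
For orientation, a small usage sketch of the buffering API documented above. EchoWorker and main are hypothetical names invented for this example and are not part of the diff; with the default collate_fn and decollate_fn, each caller simply gets its own kwargs dict back.

# Hypothetical example, not part of this PR: a concrete worker that echoes its
# inputs, used to show how concurrent calls are grouped into batches.
class EchoWorker(AsyncBufferedWorker):
    async def _batched_call(self, batch_kwargs):
        # With the default collate_fn, batch_kwargs is the raw list of per-call
        # kwargs dicts; returning it unchanged keeps decollate_fn a no-op too.
        return batch_kwargs

async def main():
    worker = EchoWorker(batch_size=4, max_wait_interval=0.1)
    # Eight concurrent calls should be served as two batches of four.
    results = await asyncio.gather(*(worker(value=i) for i in range(8)))
    print(results)

asyncio.run(main())
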
class AsyncBufferedWorker(ABC):
    def __init__(self, batch_size, max_wait_interval, collate_fn=lambda x: x, decollate_fn=lambda x: x):
        self.batch_size = batch_size
        self.timeout = max_wait_interval
        self.collate_fn = collate_fn
        self.decollate_fn = decollate_fn

        self._work_buffer: list[tuple[float, UUID, dict[str, Any]]] = []
        self._result_buffer: dict[UUID, Any] = {}
        self._lock = asyncio.Lock()
        self._new_data_event = asyncio.Event()

        self._processed_events = {}
        self._counter = 0
        self._events_count = {}

        # Start the background batch processing task
        self._batch_processing_task = asyncio.create_task(self._batch_processor())

    async def __call__(self, **kwargs):
        request_id = uuid4()
        request_ts = time.time()

        async with self._lock:
            self._processed_events[request_id] = asyncio.Event()
            self._events_count[request_id] = self._counter
            self._counter += 1
            print(f"Started Request ID: {request_id}, Counter: {self._events_count[request_id]}")
            self._work_buffer.append((request_ts, request_id, kwargs))
            if len(self._work_buffer) >= self.batch_size:
                self._new_data_event.set()  # Signal that new data has arrived
                print(f"did set new data event, counter: {self._counter}")

        # Wait for the result to be processed
        await self._processed_events[request_id].wait()

        async with self._lock:
            print(f"Finished Request ID: {request_id}, Counter: {self._events_count[request_id]}")
            self._events_count.pop(request_id)
            self._processed_events.pop(request_id)
            return self._result_buffer.pop(request_id)

    async def _batch_processor(self):
        while True:
            try:
                # Wait for new data or timeout
                await asyncio.wait_for(self._new_data_event.wait(), timeout=self.timeout)
            except asyncio.TimeoutError:
                pass

            async with self._lock:
                if len(self._work_buffer) == 0:
                    self._new_data_event.clear()
                    continue

                now = time.time()
                # Sort the work buffer by timestamp to maintain order
                self._work_buffer.sort(key=operator.itemgetter(0))

                batch = self._work_buffer[:self.batch_size]
                self._work_buffer = self._work_buffer[self.batch_size:]
                if len(self._work_buffer) == 0:
                    self._new_data_event.clear()

            # Process the batch outside the lock
            sample_kwargs = [x[2] for x in batch]
            batch_kwargs = self.collate_fn(sample_kwargs)
            print(f"Starting batched call, counter: {self._counter}, batch size: {len(batch)}")
            batched_results = await self._batched_call(batch_kwargs)
            print(f"Finished batched call, counter: {self._counter}")
            request_ids = [x[1] for x in batch]
            results = self.decollate_fn(batched_results)
            async with self._lock:
                print(f"Updating result buffer, counter: {self._counter}")
                self._result_buffer.update(zip(request_ids, results))
                for request_id in request_ids:
                    self._processed_events[request_id].set()

            # Let other requests proceed as soon as their result is available
            await asyncio.sleep(0.0)

    @abstractmethod
    async def _batched_call(self, batch_kwargs: dict[str, Any]):
        """Logic to call the worker on a batch of inputs."""
        pass
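
A note on the background-task variant above: the task created in __init__ with asyncio.create_task runs until the worker is garbage collected, so long-lived applications would typically cancel it explicitly on shutdown. A possible cleanup hook, sketched here as an assumption (the aclose name and the method itself are not part of this PR):

    # Hypothetical shutdown helper for the background-task variant; not in the diff.
    async def aclose(self) -> None:
        self._batch_processing_task.cancel()
        try:
            await self._batch_processing_task
        except asyncio.CancelledError:
            # Cancellation is the expected way for the processing loop to stop.
            pass
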
class AsyncTorchModule(AsyncBufferedWorker):
Can we import this at the module level?