8 changes: 6 additions & 2 deletions ldp/alg/rollout.py
@@ -194,12 +194,16 @@ async def _sample_trajectories_from_envs(
self.traj_buffer.clear()

traj_ids = [uuid.uuid4().hex for _ in range(len(environments))]
await asyncio.gather(
from tqdm.asyncio import tqdm_asyncio
Collaborator comment: Can we import this at the module level?


await tqdm_asyncio.gather(
*(
self._rollout(*args, max_steps=max_steps)
for args in zip(traj_ids, environments, strict=True)
)
),
desc="Sampling trajectories"
)

return [self.traj_buffer[traj_id] for traj_id in traj_ids]

async def _rollout(
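A sketch (not part of the diff) of what the reviewer's module-level import suggestion could look like in ldp/alg/rollout.py, keeping the call site as shown in the hunk above:

    # At the top of ldp/alg/rollout.py, alongside the other imports
    from tqdm.asyncio import tqdm_asyncio

    ...

        await tqdm_asyncio.gather(
            *(
                self._rollout(*args, max_steps=max_steps)
                for args in zip(traj_ids, environments, strict=True)
            ),
            desc="Sampling trajectories",
        )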
199 changes: 199 additions & 0 deletions ldp/graph/async_torch.py
@@ -146,6 +146,205 @@ async def _maybe_process_batch(self):
@abstractmethod
async def _batched_call(self, batch_kwargs: dict[str, Any]):
"""Logic to call the worker on a batch of inputs."""



class AsyncBufferedWorker(ABC):
"""Abstract class for a worker that buffers inputs and processes them in batches."""

def __init__(
self,
batch_size: int,
max_wait_interval: float,
collate_fn: Callable = lambda x: x,
decollate_fn: Callable = lambda x: x,
):
"""Initialize.

Args:
batch_size: The target batch size to use when calling the module. As soon as
batch_size calls are made, a forward pass is executed.
max_wait_interval: The maximum time to wait for a batch to fill up before
executing the calls we have buffered.
collate_fn: A function to pre-process a list of inputs into a batch. Defaults to a
no-op.
decollate_fn: The inverse of collate_fn: this function should take
the batched output and return an ordered list of per-request outputs. Defaults to a no-op.
"""
self.batch_size = batch_size
self.timeout = max_wait_interval
self.collate_fn = collate_fn
self.decollate_fn = decollate_fn

self._work_buffer: list[tuple[float, UUID, dict[str, Any]]] = []
self._result_buffer: dict[UUID, Any] = {}
self._lock = asyncio.Lock()
self._batch_ready_event = asyncio.Event()
self._processed_events = {}
self._counter = 0
self._events_count = {}

async def __call__(self, **kwargs):
request_id = uuid4()
request_ts = time.time()

async with self._lock:
self._processed_events[request_id] = asyncio.Event()
self._events_count[request_id] = self._counter
self._counter += 1
print(f"Started Request ID: {request_id}, Counter: {self._events_count[request_id]}")
Collaborator comment: Let's use logging over print
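A minimal sketch (not part of the diff) of the module-level logger the reviewer is asking for, assuming the standard library logging module; the message mirrors the print call below:

    import logging

    logger = logging.getLogger(__name__)

    ...

        logger.debug("Started request %s, counter: %d", request_id, self._events_count[request_id])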

self._work_buffer.append((request_ts, request_id, kwargs))

# If we've reached batch size, we trigger the processing event immediately
if len(self._work_buffer) >= self.batch_size:
self._batch_ready_event.set()

try:
# Wait for either the batch to fill up or the timeout to expire
await asyncio.wait_for(self._batch_ready_event.wait(), timeout=self.timeout)
except asyncio.TimeoutError:
pass
Collaborator comment on lines +206 to +210:
        async with asyncio.timeout(self.timeout):
            # Wait for either the batch to fill up or the timeout to expire
            self._batch_ready_event.wait()

Alternate way of doing this using built-in asyncio.timeout

I am not sure if you need to await the wait()
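
For reference, a hedged sketch of how that asyncio.timeout suggestion would look in full: Event.wait() is a coroutine, so it does need to be awaited, and asyncio.timeout (Python 3.11+) raises TimeoutError on expiry, so the exception still has to be suppressed to fall through to batch processing:

    try:
        async with asyncio.timeout(self.timeout):
            # Wait for either the batch to fill up or the timeout to expire
            await self._batch_ready_event.wait()
    except TimeoutError:
        pass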


await self._maybe_process_batch()

await self._processed_events[request_id].wait()

async with self._lock:
print(f"Finished Request ID: {request_id}, Counter: {self._events_count[request_id]}")
Collaborator comment: Let's use logging not print

self._events_count.pop(request_id)
self._processed_events.pop(request_id)
return self._result_buffer.pop(request_id)

async def _maybe_process_batch(self):
"""If the buffer is >= batch size or we have been waiting long enough, process the old batch.

If neither condition is met, do nothing.
"""
async with self._lock:
# If there's at least one request in the buffer, we can process it
if len(self._work_buffer) == 0:
return

self._work_buffer.sort(key=operator.itemgetter(0))

batch = self._work_buffer[: self.batch_size]
self._work_buffer = self._work_buffer[self.batch_size :]

if len(self._work_buffer) < self.batch_size:
self._batch_ready_event.clear()

# Construct the batch tensors
sample_kwargs = [x[2] for x in batch]
batch_kwargs = self.collate_fn(sample_kwargs)

print(f"starting to wait for batched call, counter: {self._counter}")
batched_results = await self._batched_call(batch_kwargs)
print(f"finished waiting for batched call, counter: {self._counter}")
request_ids = [x[1] for x in batch]
results = self.decollate_fn(batched_results)
async with self._lock:
print(f"updating result buffer, counter: {self._counter}")
self._result_buffer.update(zip(request_ids, results, strict=True))
for request_id in request_ids:
self._processed_events[request_id].set()

def _process_batch(self):
"""Processes the current batch."""


@abstractmethod
async def _batched_call(self, batch_kwargs: dict[str, Any]):
"""Logic to call the worker on a batch of inputs."""



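As a reading aid (not part of the diff), a minimal usage sketch of the event-based AsyncBufferedWorker above. The EchoWorker subclass, its doubling _batched_call, and the collate/decollate functions are hypothetical and exist only to illustrate the interface:

    import asyncio

    class EchoWorker(AsyncBufferedWorker):  # hypothetical subclass for illustration
        async def _batched_call(self, batch_kwargs):
            # Stand-in for a real batched forward pass over all buffered inputs
            await asyncio.sleep(0.01)
            return [value * 2 for value in batch_kwargs["value"]]

    async def main():
        worker = EchoWorker(
            batch_size=4,
            max_wait_interval=0.1,
            # collate: list of per-request kwargs -> one dict of lists
            collate_fn=lambda samples: {"value": [s["value"] for s in samples]},
            # decollate: batched output -> ordered list of per-request outputs
            decollate_fn=list,
        )
        results = await asyncio.gather(*(worker(value=i) for i in range(8)))
        assert results == [i * 2 for i in range(8)]

    asyncio.run(main())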

class AsyncBufferedWorker(ABC):
def __init__(self, batch_size, max_wait_interval, collate_fn=lambda x: x, decollate_fn=lambda x: x):
self.batch_size = batch_size
self.timeout = max_wait_interval
self.collate_fn = collate_fn
self.decollate_fn = decollate_fn

self._work_buffer: list[tuple[float, UUID, dict[str, Any]]] = []
self._result_buffer: dict[UUID, Any] = {}
self._lock = asyncio.Lock()
self._new_data_event = asyncio.Event()

self._processed_events = {}
self._counter = 0
self._events_count = {}

# Start the background batch processing task
self._batch_processing_task = asyncio.create_task(self._batch_processor())

async def __call__(self, **kwargs):
request_id = uuid4()
request_ts = time.time()

async with self._lock:
self._processed_events[request_id] = asyncio.Event()
self._events_count[request_id] = self._counter
self._counter += 1
print(f"Started Request ID: {request_id}, Counter: {self._events_count[request_id]}")
self._work_buffer.append((request_ts, request_id, kwargs))
if len(self._work_buffer) >= self.batch_size:
self._new_data_event.set() # Signal that new data has arrived
print(f"did set new data event, counter: {self._counter}")

# Wait for the result to be processed
await self._processed_events[request_id].wait()

async with self._lock:
print(f"Finished Request ID: {request_id}, Counter: {self._events_count[request_id]}")
self._events_count.pop(request_id)
self._processed_events.pop(request_id)
return self._result_buffer.pop(request_id)

async def _batch_processor(self):
while True:
try:
# Wait for new data or timeout
await asyncio.wait_for(self._new_data_event.wait(), timeout=self.timeout)
except asyncio.TimeoutError:
pass

async with self._lock:
if len(self._work_buffer) == 0:
self._new_data_event.clear()
continue

now = time.time()
# Sort the work buffer by timestamp to maintain order
self._work_buffer.sort(key=operator.itemgetter(0))

batch = self._work_buffer[:self.batch_size]
self._work_buffer = self._work_buffer[self.batch_size:]
if len(self._work_buffer) == 0:
self._new_data_event.clear()

# Process the batch outside the lock
sample_kwargs = [x[2] for x in batch]
batch_kwargs = self.collate_fn(sample_kwargs)
print(f"Starting batched call, counter: {self._counter}, batch size: {len(batch)}")
batched_results = await self._batched_call(batch_kwargs)
print(f"Finished batched call, counter: {self._counter}")
request_ids = [x[1] for x in batch]
results = self.decollate_fn(batched_results)
async with self._lock:
print(f"Updating result buffer, counter: {self._counter}")
self._result_buffer.update(zip(request_ids, results))
for request_id in request_ids:
self._processed_events[request_id].set()

# Let other requests proceed as soon as their result is available
await asyncio.sleep(0.0)

@abstractmethod
async def _batched_call(self, batch_kwargs: dict[str, Any]):
"""Logic to call the worker on a batch of inputs."""
pass
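
One practical note on this second, background-task variant (an observation, not something stated in the diff): asyncio.create_task in __init__ needs a running event loop, so instances must be constructed from within async code, roughly:

    async def main():
        # hypothetical subclass implementing _batched_call, as in the earlier sketch
        worker = EchoWorker(batch_size=4, max_wait_interval=0.1)
        return await asyncio.gather(*(worker(value=i) for i in range(8)))

    asyncio.run(main())  # constructing the worker outside a running loop raises RuntimeError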



class AsyncTorchModule(AsyncBufferedWorker):