
Commit db9d283

[Feature] batching for vllm and transformers wrappers (#3103)
1 parent 0c05bba commit db9d283

File tree: 4 files changed, +315 -2 lines changed

test/llm/test_wrapper.py

Lines changed: 179 additions & 1 deletion
@@ -8,6 +8,7 @@
 import importlib.util

 import os
+import time
 from functools import partial

 import pytest
@@ -72,7 +73,11 @@ def vllm_instance():
     assert os.environ.get("VLLM_USE_V1") == "0"

     try:
-        model = LLM("Qwen/Qwen2.5-0.5B")
+        model = LLM(
+            "Qwen/Qwen2.5-0.5B",
+            max_num_batched_tokens=32768,  # Match max_model_len
+            max_model_len=32768,
+        )
         tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
         tokenizer.pad_token = tokenizer.eos_token
         return model, tokenizer
@@ -717,6 +722,179 @@ def test_generate_false_without_log_probs(
             return_log_probs=False,
         )

+    # ================================================
+    # Batching Tests
+    # ================================================
+
+    @pytest.mark.parametrize(
+        "wrapper_class",
+        [vLLMWrapper, TransformersWrapperMaxTokens],
+        ids=["vllm", "transformers"],
+    )
+    def test_batching(self, wrapper_class, vllm_instance, transformers_instance):
+        from concurrent.futures import ThreadPoolExecutor, wait
+
+        # Handle the case where vLLM is not available
+        if wrapper_class == vLLMWrapper:
+            try:
+                model, tokenizer = vllm_instance
+            except Exception as e:
+                if "vLLM compatibility issue" in str(e):
+                    pytest.skip("vLLM not available due to compatibility issues")
+                raise
+        else:
+            model, tokenizer = transformers_instance
+
+        wrapper = wrapper_class(
+            model,
+            tokenizer=tokenizer,
+            input_mode="text",
+            generate=True,
+            return_log_probs=True,
+            batch_size=4,
+        )
+        # Create 2 threads and send inputs
+        inputs = [
+            TensorDict(
+                text=Text(prompt=[f"Question {i}?", f"Question {i+2}?"]),
+                batch_size=(2,),
+            )
+            for i in range(2)
+        ]
+        pool = ThreadPoolExecutor(max_workers=2)
+        try:
+            futures = [pool.submit(wrapper, input) for input in inputs]
+            wait(futures)
+        finally:
+            pool.shutdown(wait=False, cancel_futures=True)
+
+    @pytest.mark.parametrize(
+        "wrapper_class",
+        [vLLMWrapper, TransformersWrapperMaxTokens],
+        ids=["vllm", "transformers"],
+    )
+    def test_batching_uneven(self, wrapper_class, vllm_instance, transformers_instance):
+        from concurrent.futures import ThreadPoolExecutor, wait
+
+        if wrapper_class == vLLMWrapper:
+            model, tokenizer = vllm_instance
+        else:
+            model, tokenizer = transformers_instance
+        wrapper = wrapper_class(
+            model,
+            tokenizer=tokenizer,
+            input_mode="text",
+            generate=True,
+            return_log_probs=True,
+            batch_size=5,
+            batching_timeout=5,  # Increased timeout for CI environments
+        )
+        inputs = [
+            TensorDict(text=Text(prompt=["Question 1?"]), batch_size=(1,)),
+            TensorDict(
+                text=Text(prompt=["Question 2?", "Question 3?", "Question 4?"]),
+                batch_size=(3,),
+            ),
+            TensorDict(
+                text=Text(prompt=["Question 5?", "Question 6?"]), batch_size=(2,)
+            ),
+        ]
+        pool = ThreadPoolExecutor(max_workers=3)
+        try:
+            futures = []
+            for input in inputs:
+                futures.append(pool.submit(wrapper, input))
+                time.sleep(0.05)  # Increased delay for more reliable timing
+
+            # Wait for first two futures with longer timeout
+            wait(futures[:2], timeout=3)
+
+            # Check results with more flexible assertions
+            result0 = futures[0].result()
+            result1 = futures[1].result()
+
+            assert result0["text"].prompt == ["Question 1?"]
+            assert result1["text"].prompt == [
+                "Question 2?",
+                "Question 3?",
+                "Question 4?",
+            ]
+
+            # The third future may or may not be done depending on timing
+            # Wait for it with a reasonable timeout
+            wait(futures[2:], timeout=10)
+            if not futures[2].done():
+                raise RuntimeError("Third future not done")
+            result2 = futures[2].result()
+            assert result2["text"].prompt == ["Question 5?", "Question 6?"]
+        finally:
+            pool.shutdown(wait=False, cancel_futures=True)
+
+    @pytest.mark.parametrize(
+        "wrapper_class",
+        [vLLMWrapper, TransformersWrapperMaxTokens],
+        ids=["vllm", "transformers"],
+    )
+    def test_batching_cleanup(
+        self, wrapper_class, vllm_instance, transformers_instance
+    ):
+        """Test batching cleanup functionality."""
+        if wrapper_class == vLLMWrapper:
+            model, tokenizer = vllm_instance
+        else:
+            model, tokenizer = transformers_instance
+
+        wrapper = wrapper_class(
+            model,
+            tokenizer=tokenizer,
+            input_mode="text",
+            generate=True,
+            return_log_probs=True,
+            batch_size=3,
+        )
+
+        # Check initial state
+        state = wrapper.get_batching_state()
+        assert state["batching_enabled"] is True
+        assert state["batch_size"] == 3
+        assert state["queue_size"] == 0
+        assert state["pending_futures"] == 0
+
+        # Add some inputs to the queue
+        input1 = TensorDict(text=Text(prompt=["Test 1"]), batch_size=(1,))
+        input2 = TensorDict(text=Text(prompt=["Test 2"]), batch_size=(1,))
+
+        # Submit inputs (they won't be processed immediately due to batch size)
+        from concurrent.futures import ThreadPoolExecutor
+
+        pool = ThreadPoolExecutor(max_workers=1)
+        try:
+            future1 = pool.submit(wrapper, input1)
+            future2 = pool.submit(wrapper, input2)
+
+            # Check state after adding inputs
+            state = wrapper.get_batching_state()
+            assert state["queue_size"] >= 0  # May be 0 if processed immediately
+            assert state["pending_futures"] >= 0
+
+            # Clean up
+            wrapper.cleanup_batching()
+
+            # Check state after cleanup
+            state = wrapper.get_batching_state()
+            assert state["queue_size"] == 0
+            assert state["pending_futures"] == 0
+
+            # Wait for futures to complete or fail
+            try:
+                future1.result(timeout=5)
+                future2.result(timeout=5)
+            except Exception:
+                # Futures may fail after cleanup, which is expected
+                pass
+        finally:
+            pool.shutdown(wait=False, cancel_futures=True)
+
     # ================================================
     # Batch Size Tests
     # ================================================

torchrl/modules/llm/policies/common.py

Lines changed: 105 additions & 0 deletions
@@ -6,6 +6,8 @@

 import warnings
 import weakref
+
+from functools import wraps
 from typing import Any, Literal, overload

 import torch
@@ -372,6 +374,12 @@ class LLMWrapperBase(TensorDictModuleBase):
         text_key (NestedKey | None, optional): The key for the action :class:`~torchrl.modules.llm.policies.Text` object. Defaults to `"text"`.
         tokens_key (NestedKey | None, optional): The key for the action :class:`~torchrl.modules.llm.policies.Tokens` object. Defaults to `"tokens"`.
         masks_key (NestedKey | None, optional): The key for the action :class:`~torchrl.modules.llm.policies.Masks` object. Defaults to `"masks"`.
+        batch_size (int | None, optional): The batch size to use for batching. If None, no batching is done. If provided, the module will batch the inputs and process them in batches of this size.
+            This means that a single call to the module will wait until enough inputs are available to form a batch of this size, and then process the batch.
+            This functionality uses concurrent futures to process the batches in parallel and therefore is best used in a multi-threaded environment.
+            Defaults to `None`.
+        batching_timeout (float, optional): The timeout for batching. If the batch isn't completed after `batching_timeout` seconds, the batch is processed as is.
+            Defaults to `10` seconds.

     Attributes:
         collector: The collector associated with the module, if it exists.
@@ -393,6 +401,7 @@ class LLMWrapperBase(TensorDictModuleBase):
     device: torch.device | None
     layout: torch.layout | None
     num_samples: int | None
+    _batching_timeout: float | None

     @overload
     def __init__(
@@ -419,6 +428,8 @@ def __init__(
         tokens_key: NestedKey | None = "tokens",
         masks_key: NestedKey | None = "masks",
         log_probs_key: NestedKey | None = "log_probs",
+        batch_size: int | None = None,
+        batching_timeout: float = 10.0,
     ):
         ...

@@ -907,6 +918,35 @@ def log_prob(self, data: TensorDictBase, **get_kwargs) -> TensorDictBase:
             return data.get((self.log_prob_key, "response"), **get_kwargs)
         raise RuntimeError("log_prob not callable when generate=True.")

+    def cleanup_batching(self):
+        """Clear batching queues to prevent memory leaks.
+
+        This method should be called when the wrapper is no longer needed
+        or when you want to reset the batching state.
+        """
+        if hasattr(self, "_batch_queue"):
+            self._batch_queue.clear()
+        if hasattr(self, "_futures"):
+            self._futures.clear()
+
+    def get_batching_state(self):
+        """Get the current batching state for debugging and monitoring.
+
+        Returns:
+            dict: A dictionary containing the current batching state including
+                queue size, number of pending futures, and batch size.
+        """
+        if not hasattr(self, "batch_size") or self.batch_size is None:
+            return {"batching_enabled": False}
+
+        return {
+            "batching_enabled": True,
+            "batch_size": self.batch_size,
+            "queue_size": len(getattr(self, "_batch_queue", [])),
+            "pending_futures": len(getattr(self, "_futures", [])),
+            "timeout": getattr(self, "_batching_timeout", None),
+        }
+

 def _extract_responses_from_full_histories(
     text_full: list[str],
@@ -973,3 +1013,68 @@ def _extract_responses_from_full_histories(
         return torch.stack(padded_responses)

     return torch.stack(response_histories)
+
+
+def _batching(func):
+    from concurrent.futures import Future, wait
+
+    @wraps(func)
+    def _batched_func(self, td_input: TensorDictBase, **kwargs):
+        if getattr(self, "batch_size", None) is not None:
+            # put elements in a queue until the batch size is reached
+            if td_input.batch_dims == 0:
+                inputs = [td_input]
+            else:
+                if td_input.batch_dims > 1:
+                    raise ValueError(
+                        f"Batching not supported for batch_dims > 1: {td_input.batch_dims}"
+                    )
+                inputs = list(td_input.unbind(0))
+
+            # Create as many futures as inputs
+            futures = [Future() for _ in inputs]
+
+            self._batch_queue.extend(inputs)
+            self._futures.extend(futures)
+
+            # Check if we have enough inputs to form a complete batch
+            if len(self._batch_queue) >= self.batch_size:
+                # Process full batch immediately
+                try:
+                    batch = lazy_stack(self._batch_queue[: self.batch_size])
+                    results = func(self, batch, **kwargs)
+                    batch_results = results.unbind(0)
+                    for i, future in enumerate(self._futures[: self.batch_size]):
+                        future.set_result(batch_results[i])
+                    self._batch_queue = self._batch_queue[self.batch_size :]
+                    self._futures = self._futures[self.batch_size :]
+                except Exception as e:
+                    # Set exception for all futures in this batch
+                    for future in futures:
+                        future.set_exception(e)
+                    raise
+
+            # Now wait for the current futures to complete (with timeout if needed)
+            _, not_done = wait(futures, timeout=self._batching_timeout)
+
+            # if there are still futures not done, process them as is
+            if not_done:
+                try:
+                    inputs_not_done = [
+                        inputs[futures.index(future)] for future in not_done
+                    ]
+                    results = func(self, torch.stack(inputs_not_done), **kwargs).unbind(
+                        0
+                    )
+                    for i, future in enumerate(not_done):
+                        future.set_result(results[i])
+                except Exception as e:
+                    # Set exception for remaining futures
+                    for future in not_done:
+                        future.set_exception(e)
+                    raise
+
+            return lazy_stack([future.result() for future in futures])
+        return func(self, td_input, **kwargs)
+
+    return _batched_func
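
The docstring added above notes that batching relies on concurrent futures and is meant for multi-threaded callers. To illustrate the queue/futures/timeout pattern that `_batching` follows, here is a minimal, self-contained sketch; `EchoBatcher` and its `_process` method are hypothetical stand-ins for the wrapper and the underlying model call, not part of torchrl.

# Minimal sketch (not torchrl code): queue inputs, release a full batch at once,
# and flush whatever is left after a timeout, with one Future per caller.
import threading
from concurrent.futures import Future, ThreadPoolExecutor, wait


class EchoBatcher:
    def __init__(self, batch_size=2, timeout=1.0):
        self.batch_size = batch_size
        self.timeout = timeout
        self._queue = []      # pending inputs
        self._futures = []    # one Future per pending input
        self._lock = threading.Lock()

    def _process(self, batch):
        # Stand-in for the expensive batched call (e.g. LLM generation).
        return [item.upper() for item in batch]

    def _flush(self, n):
        # Pop the first n queued inputs, run them as one batch, resolve their futures.
        batch, futs = self._queue[:n], self._futures[:n]
        self._queue, self._futures = self._queue[n:], self._futures[n:]
        for fut, res in zip(futs, self._process(batch)):
            fut.set_result(res)

    def __call__(self, item):
        fut = Future()
        with self._lock:
            self._queue.append(item)
            self._futures.append(fut)
            if len(self._queue) >= self.batch_size:
                self._flush(self.batch_size)  # a full batch is ready: run it now
        # Wait for our own future; if the batch never fills up, flush the stragglers.
        _, not_done = wait([fut], timeout=self.timeout)
        if not_done:
            with self._lock:
                if not fut.done():
                    self._flush(len(self._queue))
        return fut.result()


# Two threads submit one item each; the second call completes the batch of 2.
batcher = EchoBatcher(batch_size=2, timeout=0.5)
with ThreadPoolExecutor(max_workers=2) as pool:
    futures = [pool.submit(batcher, s) for s in ("hello", "world")]
    print([f.result() for f in futures])  # ['HELLO', 'WORLD']

As in the decorator above, whichever caller completes the batch runs the expensive call and sets the results, while the other callers simply wait on their futures.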

torchrl/modules/llm/policies/transformers_wrapper.py

Lines changed: 15 additions & 0 deletions
@@ -25,6 +25,7 @@
 from torch.nn.utils.rnn import pad_sequence

 from torchrl.modules.llm.policies.common import (
+    _batching,
     _extract_responses_from_full_histories,
     ChatHistory,
     LLMWrapperBase,
@@ -99,6 +100,12 @@ class TransformersWrapper(LLMWrapperBase):
         tokens_key (NestedKey | None, optional): The key for the action :class:`~torchrl.modules.llm.policies.Tokens` object. Defaults to `"tokens"`.
         masks_key (NestedKey | None, optional): The key for the action :class:`~torchrl.modules.llm.policies.Masks` object. Defaults to `"masks"`.
         history_key (NestedKey | None, optional): The key for the action :class:`~torchrl.modules.llm.policies.ChatHistory` object. Defaults to `"history"`.
+        batch_size (int | None, optional): The batch size to use for batching. If None, no batching is done. If provided, the module will batch the inputs and process them in batches of this size.
+            This means that a single call to the module will wait until enough inputs are available to form a batch of this size, and then process the batch.
+            This functionality uses concurrent futures to process the batches in parallel and therefore is best used in a multi-threaded environment.
+            Defaults to `None`.
+        batching_timeout (float, optional): The timeout for batching. If the batch isn't completed after `batching_timeout` seconds, the batch is processed as is.
+            Defaults to `10` seconds.

     Input Keys:
         The input key depends on both `input_mode` and `generate`:
@@ -188,9 +195,16 @@ def __init__(
         tokens_key: NestedKey | None = "tokens",
         masks_key: NestedKey | None = "masks",
         log_probs_key: NestedKey | None = "log_probs",
+        batch_size: int | None = None,
+        batching_timeout: float = 10.0,
     ):
         super().__init__()

+        self.batch_size = batch_size
+        self._batching_timeout = batching_timeout
+        self._batch_queue = []
+        self._futures = []
+
         if isinstance(model, str):
             from transformers import AutoModelForCausalLM

@@ -489,6 +503,7 @@ def get_new_version(self, **kwargs):
         return type(self)(**constructor_kwargs)

     @set_list_to_stack(True)
+    @_batching
     def forward(
         self,
         tensordict: TensorDictBase,
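
Putting the new arguments together, here is a usage sketch rather than a snippet from the repo. It assumes `TransformersWrapper` and `Text` are importable from `torchrl.modules.llm.policies` (the path used by the docstring cross-references) and that the `Qwen/Qwen2.5-0.5B` checkpoint from the test fixture is available.

# Hedged usage sketch: drive a batched TransformersWrapper from several threads.
from concurrent.futures import ThreadPoolExecutor

from tensordict import TensorDict
from transformers import AutoTokenizer

from torchrl.modules.llm.policies import Text, TransformersWrapper  # assumed import path

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
tokenizer.pad_token = tokenizer.eos_token

wrapper = TransformersWrapper(
    "Qwen/Qwen2.5-0.5B",   # a string model is loaded via AutoModelForCausalLM (see __init__)
    tokenizer=tokenizer,
    input_mode="text",
    generate=True,
    batch_size=2,          # hold calls until 2 prompts are queued...
    batching_timeout=5.0,  # ...or flush whatever is queued after 5 seconds
)

inputs = [
    TensorDict(text=Text(prompt=["Question 1?"]), batch_size=(1,)),
    TensorDict(text=Text(prompt=["Question 2?"]), batch_size=(1,)),
]

pool = ThreadPoolExecutor(max_workers=2)
try:
    futures = [pool.submit(wrapper, td) for td in inputs]
    print(wrapper.get_batching_state())  # batch size, queue size, pending futures, timeout
    results = [fut.result() for fut in futures]
    print([res["text"].prompt for res in results])
finally:
    wrapper.cleanup_batching()           # drop any queued inputs and pending futures
    pool.shutdown(wait=False, cancel_futures=True)

Because the wrapper waits for a full batch (or the timeout), calling it from a single thread with `batch_size` set would simply block until `batching_timeout` expires, which is why the docstring recommends multi-threaded use.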
