Commit 0e4dabe

[Feature] Remote LLM wrappers and batching (#3116)
1 parent 744f061 commit 0e4dabe

6 files changed: +784 -14 lines changed

docs/source/reference/llms.rst

Lines changed: 95 additions & 0 deletions
@@ -436,12 +436,107 @@ The main goal of these primitives is to:
     LLMWrapperBase
     TransformersWrapper
     vLLMWrapper
+    RemoteTransformersWrapper
+    RemotevLLMWrapper
     ChatHistory
     Text
     LogProbs
     Masks
     Tokens
 
+Remote Wrappers
+^^^^^^^^^^^^^^^
+
+TorchRL provides remote wrapper classes that enable distributed execution of LLM wrappers via Ray. These wrappers expose a simplified interface that requires no explicit ``remote()`` and ``get()`` calls, which makes them easy to use in distributed settings.
+
+**Key Features:**
+
+- **Simplified Interface**: No need to call ``remote()`` and ``get()`` explicitly
+- **Full API Compatibility**: Exposes all public methods of the base ``LLMWrapperBase`` class
+- **Automatic Ray Management**: Handles Ray initialization and remote execution internally
+- **Property Access**: All properties are accessible through the remote wrapper
+- **Error Handling**: Errors raised on remote actors are propagated to the caller
+- **Resource Management**: Context-manager support for automatic cleanup
+
+**Model Parameter Requirements:**
+
+- **RemotevLLMWrapper**: Accepts string model names/paths (recommended) or remote vLLM LLM objects with Ray handles. Local vLLM models are not serializable.
+- **RemoteTransformersWrapper**: Accepts string model names/paths only, since Transformers models are not serializable.
+
+**Usage Examples:**
+
+.. code-block:: python
+
+    import ray
+    from tensordict import TensorDict
+
+    from torchrl.data.llm import History
+    from torchrl.modules.llm.policies import (
+        ChatHistory,
+        RemoteTransformersWrapper,
+        RemotevLLMWrapper,
+        Text,
+    )
+
+    # Initialize Ray (if not already done)
+    if not ray.is_initialized():
+        ray.init()
+
+    # Use a context manager for proper cleanup (recommended)
+    with RemotevLLMWrapper(
+        model="gpt2",
+        max_concurrency=16,  # control concurrent calls
+        input_mode="history",
+        generate=True,
+        generate_kwargs={"max_new_tokens": 50, "temperature": 0.7},
+    ) as remote_wrapper:
+
+        # Create a test input
+        history = History.from_chats([[
+            {"role": "user", "content": "Hello, how are you?"}
+        ]])
+        chat_history = ChatHistory(prompt=history)
+        tensordict_input = TensorDict(history=chat_history, batch_size=(1,))
+
+        # Use like a regular wrapper (no remote()/get() calls needed!)
+        result = remote_wrapper(tensordict_input)
+        print(result["text"].response)
+
+    # Transformers wrapper (string models only)
+    with RemoteTransformersWrapper(
+        model="gpt2",
+        max_concurrency=16,
+        input_mode="text",
+        generate=True,
+        generate_kwargs={"max_new_tokens": 30},
+    ) as remote_transformers:
+
+        text_input = TensorDict({"text": Text(prompt="Hello world")}, batch_size=(1,))
+        result = remote_transformers(text_input)
+        print(result["text"].response)
+
+**Cleanup and Resource Management:**
+
+The remote wrappers implement the context-manager protocol for proper resource cleanup:
+
+.. code-block:: python
+
+    # Context manager (recommended)
+    with RemotevLLMWrapper(model="gpt2") as wrapper:
+        result = wrapper(input_data)
+    # Cleanup is automatic when exiting the context
+
+    # Manual cleanup
+    wrapper = RemotevLLMWrapper(model="gpt2")
+    try:
+        result = wrapper(input_data)
+    finally:
+        wrapper.cleanup_batching()  # important: prevents hanging on batching locks
+
+**Performance Considerations:**
+
+- **Network Overhead**: Remote execution adds network communication overhead
+- **Serialization**: Data is serialized whenever it is sent to a remote actor
+- **Memory**: Each remote actor maintains its own copy of the model
+- **Concurrency**: Multiple remote wrappers can run concurrently
+- **Max Concurrency**: Use the ``max_concurrency`` parameter to bound the number of concurrent calls to each remote actor
+- **Cleanup**: Always use a context manager or call ``cleanup_batching()`` to avoid hanging on batching locks
+
 Utils
 ^^^^^
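
The documentation above demonstrates `max_concurrency`, while the new `batching` constructor flag introduced by this commit only appears in the test below. A minimal sketch of the batching usage, assuming the constructor arguments from this commit's test (`batching=True`, a string model name) and an already-initialized Ray cluster; an illustration, not an official recipe:

    from concurrent.futures import ThreadPoolExecutor

    from tensordict import TensorDict
    from torchrl.modules.llm.policies import RemotevLLMWrapper, Text

    # Sketch only: arguments mirror test/llm/test_wrapper.py below.
    with RemotevLLMWrapper(
        model="Qwen/Qwen2.5-0.5B",
        input_mode="text",
        generate=True,
        batching=True,  # coalesce concurrent single-item calls into one batch
        generate_kwargs={"max_new_tokens": 10},
    ) as wrapper:
        data = TensorDict(text=Text(prompt="Hello"), batch_size=())
        # Ten threads submit unbatched inputs; the wrapper batches them
        # internally and hands one unbatched result back to each caller.
        with ThreadPoolExecutor(max_workers=10) as executor:
            futures = [executor.submit(wrapper, data) for _ in range(10)]
            results = [f.result() for f in futures]
        assert all(r.batch_size == () for r in results)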

test/llm/test_wrapper.py

Lines changed: 47 additions & 0 deletions
@@ -2850,6 +2850,53 @@ def __init__(self):
         assert result["text"].prompt == "Single question without batch dimension?"
 
 
+class TestRayWrapper:
+    @pytest.mark.parametrize("backend", ["transformers", "vllm"])
+    def test_ray_wrapper(self, sample_text, backend):
+        import gc
+        from concurrent.futures import ThreadPoolExecutor
+
+        from torchrl import logger as torchrl_logger
+        from torchrl.modules.llm.policies import (
+            RemoteTransformersWrapper,
+            RemotevLLMWrapper,
+        )
+
+        # check that the wrapper is remote
+        if backend == "vllm":
+            cls = RemotevLLMWrapper
+        elif backend == "transformers":
+            cls = RemoteTransformersWrapper
+        else:
+            raise ValueError(f"Invalid backend: {backend}")
+        model = cls(
+            model="Qwen/Qwen2.5-0.5B",
+            generate=True,
+            input_mode="text",
+            batching=True,
+            generate_kwargs={"max_new_tokens": 10},
+        )
+        try:
+            # check batching
+            data = TensorDict(
+                text=Text(prompt=sample_text[0]),
+                batch_size=(),
+            )
+            with ThreadPoolExecutor(max_workers=10) as executor:
+                futures = [executor.submit(model, data) for _ in range(10)]
+                torchrl_logger.info(f"Futures: {futures}")
+                results = [future.result() for future in futures]
+                torchrl_logger.info(f"Results: {results}")
+            assert all(result.batch_size == () for result in results)
+            assert all(
+                isinstance(result["text"].response, str) for result in results
+            )
+            torchrl_logger.info("Batching test passed")
+        finally:
+            del model
+            gc.collect()
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

torchrl/modules/llm/policies/__init__.py

Lines changed: 4 additions & 2 deletions
@@ -6,13 +6,15 @@
 from __future__ import annotations
 
 from .common import ChatHistory, LLMWrapperBase, LogProbs, Masks, Text, Tokens
-from .transformers_wrapper import TransformersWrapper
+from .transformers_wrapper import RemoteTransformersWrapper, TransformersWrapper
 
-from .vllm_wrapper import vLLMWrapper
+from .vllm_wrapper import RemotevLLMWrapper, vLLMWrapper
 
 __all__ = [
     "TransformersWrapper",
+    "RemoteTransformersWrapper",
     "vLLMWrapper",
+    "RemotevLLMWrapper",
     "LLMWrapperBase",
     "Text",
     "LogProbs",

torchrl/modules/llm/policies/common.py

Lines changed: 7 additions & 1 deletion
@@ -7,7 +7,6 @@
 import threading
 import warnings
 import weakref
-
 from functools import wraps
 from typing import Any, Literal, overload
 
@@ -1334,6 +1333,13 @@ def _batched_func(self, td_input: TensorDictBase, **kwargs):
            ).unbind(0)
            for i, future in enumerate(not_done):
                future.set_result(results[i])
+           # prune the just-completed futures (and their queued inputs);
+           # `not_done` holds the futures resolved in this pass
+           self._batch_queue = [
+               q
+               for q, f in zip(self._batch_queue, futures)
+               if f not in not_done
+           ]
+           self._futures = [f for f in self._futures if f not in not_done]
        except Exception as e:
            # Set exception for remaining futures that haven't been completed yet
            for future in not_done:
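
For context on the hunk above: `not_done` holds the futures that were just resolved in this flush, so the two list comprehensions drop exactly those entries from `_batch_queue` and `_futures`. Without the pruning, the shared queue keeps growing and later flushes would re-process already-answered requests. A standalone sketch of the pattern, with hypothetical names (`BatchQueue`, `inputs`, `pending`) rather than TorchRL's actual internals:

    from concurrent.futures import Future

    class BatchQueue:
        """Toy stand-in for the wrapper's internal batching state."""

        def __init__(self) -> None:
            self.inputs: list[str] = []      # queued inputs, aligned with `pending`
            self.pending: list[Future] = []  # one future per queued input

        def submit(self, item: str) -> Future:
            fut: Future = Future()
            self.inputs.append(item)
            self.pending.append(fut)
            return fut

        def flush(self, batch_size: int) -> None:
            # Resolve the front slice of waiting futures...
            done = self.pending[:batch_size]
            for item, fut in zip(self.inputs, done):
                fut.set_result(item.upper())  # stand-in for model inference
            # ...then prune the resolved futures and their inputs, keeping the
            # two lists aligned. This mirrors the hunk above.
            self.inputs = [q for q, f in zip(self.inputs, self.pending) if f not in done]
            self.pending = [f for f in self.pending if f not in done]

    queue = BatchQueue()
    futs = [queue.submit(s) for s in ("a", "b", "c")]
    queue.flush(batch_size=2)
    assert futs[0].result() == "A" and not futs[2].done()
    assert len(queue.pending) == 1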

0 commit comments
