
Commit d8dde2e (1 parent: 9eb009c)

[Performance] Fix queuing in llm wrappers (#3125)

4 files changed: +169 −135 lines

test/llm/test_wrapper.py

Lines changed: 79 additions & 88 deletions
@@ -38,6 +38,7 @@
 _has_transformers = importlib.util.find_spec("transformers") is not None
 _has_vllm = importlib.util.find_spec("vllm") is not None
 _has_datasets = importlib.util.find_spec("datasets") is not None
+_has_ray = importlib.util.find_spec("ray") is not None

 TransformersWrapperMaxTokens = partial(
     TransformersWrapper, generate_kwargs={"max_new_tokens": 10, "do_sample": True}
@@ -2508,94 +2509,6 @@ def test_batching_min_batch_size_one_immediate_processing(
         finally:
             pool.shutdown(wait=False, cancel_futures=True)

-    @pytest.mark.parametrize(
-        "wrapper_class",
-        [vLLMWrapper, TransformersWrapperMaxTokens],
-        ids=["vllm", "transformers"],
-    )
-    def test_batching_continuous_throughput(
-        self,
-        wrapper_class,
-        vllm_instance,
-        transformers_instance,
-        monkey_patch_forward_for_instrumentation,
-    ):
-        """Test that the wrapper stays busy with continuous requests."""
-        import time
-        from concurrent.futures import ThreadPoolExecutor, wait
-
-        # Create wrapper using helper function
-        wrapper = create_batching_test_wrapper(
-            wrapper_class,
-            vllm_instance,
-            transformers_instance,
-            min_batch_size=1,
-            max_batch_size=2,  # Small batch size to maximize throughput
-            batching_timeout=5.0,
-        )
-
-        # Monkey patch the forward method using fixture
-        processing_events = monkey_patch_forward_for_instrumentation[
-            "processing_events"
-        ]
-
-        # Submit continuous requests
-        futures = []
-        pool = ThreadPoolExecutor(max_workers=5)
-        try:
-            # Submit requests rapidly
-            for i in range(10):
-                input_td = TensorDict(
-                    text=Text(prompt=[f"Continuous request {i}"]), batch_size=(1,)
-                )
-                future = pool.submit(wrapper.instrumented_forward, input_td)
-                futures.append(future)
-                time.sleep(0.02)  # Small delay between submissions
-
-            # Wait for all futures to complete
-            wait(futures, timeout=30)
-
-            # Verify all futures completed successfully
-            for future in futures:
-                result = future.result(timeout=5)
-                assert "text" in result
-
-            # Analyze processing patterns
-            assert len(processing_events) > 0, "No processing occurred"
-
-            # Check that processing happened across multiple threads (indicating concurrent processing)
-            thread_ids = {event["thread_id"] for event in processing_events}  # noqa
-            assert (
-                len(thread_ids) > 1
-            ), f"All processing happened in single thread: {thread_ids}"
-
-            # Check that we have multiple processing events (indicating continuous activity)
-            assert (
-                len(processing_events) >= 5
-            ), f"Too few processing events: {len(processing_events)}"
-
-            # Check that batches were formed (some batch sizes > 1)
-            batch_sizes = [event["batch_size"] for event in processing_events]
-            assert any(
-                bs > 1 for bs in batch_sizes
-            ), f"No batching occurred: {batch_sizes}"
-
-            # Check processing timing - should be relatively continuous
-            timestamps = [event["timestamp"] for event in processing_events]
-            time_diffs = [
-                timestamps[i + 1] - timestamps[i] for i in range(len(timestamps) - 1)
-            ]
-
-            # Most time differences should be small (indicating continuous processing)
-            small_diffs = [diff for diff in time_diffs if diff < 1.0]
-            assert (
-                len(small_diffs) >= len(time_diffs) * 0.7
-            ), f"Too many large gaps in processing: {time_diffs}"
-        finally:
-            pool.shutdown(wait=False, cancel_futures=True)
-            del wrapper
-            gc.collect()
-
     @pytest.mark.parametrize(
         "wrapper_class",
         [vLLMWrapper, TransformersWrapperMaxTokens],
@@ -2921,6 +2834,84 @@ def test_ray_wrapper(self, sample_text, backend):
             gc.collect()


+@pytest.mark.skipif(not _has_ray, reason="Ray not available")
+class TestActorSharing:
+    """Test actor sharing functionality for Remote wrappers."""
+
+    @pytest.mark.parametrize("backend", ["transformers", "vllm"])
+    def test_actor_sharing(self, backend):
+        """Test that creating the same wrapper twice uses the same actor."""
+        import ray
+        from torchrl.modules.llm.policies import (
+            RemoteTransformersWrapper,
+            RemotevLLMWrapper,
+        )
+
+        # Initialize Ray if not already done
+        if not ray.is_initialized():
+            ray.init()
+
+        # Choose the wrapper class based on backend
+        if backend == "vllm":
+            if not _has_vllm:
+                pytest.skip("vllm not available")
+            WrapperClass = RemotevLLMWrapper
+        elif backend == "transformers":
+            if not _has_transformers:
+                pytest.skip("transformers not available")
+            WrapperClass = RemoteTransformersWrapper
+        else:
+            raise ValueError(f"Invalid backend: {backend}")
+
+        try:
+            # Create first wrapper with explicit actor name
+            wrapper1 = WrapperClass(
+                model="Qwen/Qwen2.5-0.5B",
+                generate=True,
+                input_mode="text",
+                generate_kwargs={"max_new_tokens": 5},
+                actor_name="test_shared_actor",
+            )
+
+            # Create second wrapper with same actor name
+            wrapper2 = WrapperClass(
+                model="Qwen/Qwen2.5-0.5B",
+                generate=True,
+                input_mode="text",
+                generate_kwargs={"max_new_tokens": 5},
+                actor_name="test_shared_actor",
+            )
+
+            # Check that both wrappers use the same actor
+            assert (
+                wrapper1._remote_wrapper == wrapper2._remote_wrapper
+            ), f"Wrappers should share the same actor for backend {backend}"
+
+            # Test that both wrappers work
+            test_data = TensorDict(
+                text=Text(prompt="Hello, how are you?"),
+                batch_size=(),
+            )
+
+            result1 = wrapper1(test_data)
+            result2 = wrapper2(test_data)
+
+            # Both should produce valid results
+            assert "text" in result1
+            assert "text" in result2
+            assert isinstance(result1["text"].response, str)
+            assert isinstance(result2["text"].response, str)
+
+        finally:
+            # Cleanup
+            try:
+                del wrapper1
+                del wrapper2
+                gc.collect()
+            except Exception:
+                pass
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

torchrl/modules/llm/policies/common.py

Lines changed: 47 additions & 40 deletions
@@ -1300,11 +1300,11 @@ def _extract_responses_from_full_histories(
 def _batching(func):
     @wraps(func)
     def _batched_func(self, td_input: TensorDictBase, **kwargs):
-        # -- 0. skip if batching is disabled
+        # -- 0. Bypass if batching disabled
         if not self.batching:
             return func(self, td_input, **kwargs)

-        # ── 1. Normalise input ──────────────────────────────────────────────────
+        # -- 1. Normalise --------------------------------------------------------
         if td_input.batch_dims > 1:
             raise ValueError(
                 f"Batching not supported for batch_dims > 1: {td_input.batch_dims}"
@@ -1313,52 +1313,59 @@ def _batched_func(self, td_input: TensorDictBase, **kwargs):
         single = td_input.batch_dims == 0
         inputs = [td_input] if single else list(td_input.unbind(0))
         futures = [Future() for _ in inputs]
+        pending = set(futures)  # ← track our own Futures

-        # ── 2. Enqueue work and, if first in, do the draining ───────────────────
+        # -- 2. Enqueue ----------------------------------------------------------
         self._batch_queue.extend(inputs)
         self._futures.extend(futures)

         min_bs = getattr(self, "_min_batch_size", 1)
         max_bs = getattr(self, "_max_batch_size", None)

+        # -- 3. Drain while holding the lock ------------------------------------
         with self._batching_lock:
-            # Only the thread that managed to grab the lock will run the loop
-            while len(self._batch_queue) >= min_bs:
-                # Determine slice
-                slice_size = (
-                    len(self._batch_queue)
-                    if max_bs is None
-                    else min(max_bs, len(self._batch_queue))
-                )
-                batch = self._batch_queue[:slice_size]
-                fut_slice = self._futures[:slice_size]
+            if all(f.done() for f in futures):
+                # Our items were already processed by another thread.
+                # Skip draining; other workers will handle the rest of the queue.
+                pass
+            else:
+                while len(self._batch_queue) >= min_bs:
+                    slice_size = (
+                        len(self._batch_queue)
+                        if max_bs is None
+                        else min(max_bs, len(self._batch_queue))
+                    )
+                    batch = self._batch_queue[:slice_size]
+                    fut_slice = self._futures[:slice_size]
+
+                    try:
+                        results = func(self, lazy_stack(batch), **kwargs).unbind(0)
+                        if len(results) != slice_size:
+                            raise RuntimeError(
+                                f"Expected {slice_size} results, got {len(results)}"
+                            )
+                        for fut, res in zip(fut_slice, results):
+                            fut.set_result(res)
+                            pending.discard(fut)  # ← mark as done
+                    except Exception as exc:
+                        for fut in fut_slice:
+                            fut.set_exception(exc)
+                            pending.discard(fut)
+                        raise

-                # Execute model
-                try:
-                    results = func(self, lazy_stack(batch), **kwargs).unbind(0)
-                    if len(results) != slice_size:  # sanity
-                        raise RuntimeError(
-                            f"Expected {slice_size} results, got {len(results)}"
-                        )
-                    # Fulfil the corresponding futures
-                    for fut, res in zip(fut_slice, results):
-                        fut.set_result(res)
-                except Exception as exc:
-                    for fut in fut_slice:
-                        fut.set_exception(exc)
-                    # Propagate to caller; other waiters will read the exception from their future
-                    raise
-
-                # Pop processed work
-                del self._batch_queue[:slice_size]
-                del self._futures[:slice_size]
-
-        # ── 3. Outside the lock: wait only for OUR futures (they may already be done) ──
-        wait(
-            futures
-        )  # no timeout → immediate return if set_result()/set_exception() already called
-        result = [f.result() for f in futures]
-
-        return result[0] if single else lazy_stack(result)
+                    # Pop processed work
+                    del self._batch_queue[:slice_size]
+                    del self._futures[:slice_size]
+
+                    # ---- Early-exit: all *our* Futures are done -------------------
+                    if not pending:
+                        break
+
+        # -- 4. Outside the lock: wait only on remaining (rare) -----------------
+        if pending:  # usually empty; safety for min_bs > queue size
+            wait(pending)
+        results = [f.result() for f in futures]
+
+        return results[0] if single else lazy_stack(results)

     return _batched_func

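For context on the queuing fix above, the pattern `_batching` implements can be illustrated outside TorchRL: each caller enqueues its inputs together with one `Future` per item, whichever thread holds the lock drains the shared queue in slices, and every caller then waits only on its own futures, breaking out of the drain loop once they are all fulfilled. The snippet below is a minimal, hypothetical sketch of that pattern; the `TinyBatcher` class and its `_run_model` stand-in are invented for illustration and batch plain Python lists rather than TensorDicts, so it is not the TorchRL implementation.

# Minimal, hypothetical sketch of the futures-based queue-draining pattern.
import threading
from concurrent.futures import Future, wait


class TinyBatcher:
    """Toy illustration only; not TorchRL code."""

    def __init__(self, min_batch_size=1, max_batch_size=None):
        self.min_batch_size = min_batch_size
        self.max_batch_size = max_batch_size
        self._queue = []        # pending inputs shared by all callers
        self._futures = []      # one Future per pending input
        self._lock = threading.Lock()

    def _run_model(self, batch):
        # Stand-in for the expensive batched forward pass.
        return [x * 2 for x in batch]

    def __call__(self, items):
        futures = [Future() for _ in items]
        pending = set(futures)              # track only OUR futures
        self._queue.extend(items)
        self._futures.extend(futures)

        with self._lock:
            # Another caller may already have processed our items while we
            # waited for the lock; if so, skip draining entirely.
            if not all(f.done() for f in futures):
                while len(self._queue) >= self.min_batch_size:
                    size = (
                        len(self._queue)
                        if self.max_batch_size is None
                        else min(self.max_batch_size, len(self._queue))
                    )
                    batch = self._queue[:size]
                    futs = self._futures[:size]
                    try:
                        for fut, res in zip(futs, self._run_model(batch)):
                            fut.set_result(res)
                            pending.discard(fut)
                    except Exception as exc:
                        for fut in futs:
                            fut.set_exception(exc)
                            pending.discard(fut)
                        raise
                    del self._queue[:size]
                    del self._futures[:size]
                    if not pending:         # our items are done; stop draining
                        break

        if pending:                         # rare: queue still below min_batch_size
            wait(pending)
        return [f.result() for f in futures]

Calling an instance of this sketch from several threads exercises the same early-exit the commit adds: a caller whose futures were already fulfilled by another thread skips the drain loop instead of blocking on work it no longer owns.
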
torchrl/modules/llm/policies/transformers_wrapper.py

Lines changed: 22 additions & 4 deletions
@@ -23,7 +23,7 @@
 from tensordict.utils import _zip_strict, NestedKey
 from torch import distributions as D
 from torch.nn.utils.rnn import pad_sequence
-
+from torchrl import logger as torchrl_logger
 from torchrl.modules.llm.policies.common import (
     _batching,
     _extract_responses_from_full_histories,
@@ -2443,7 +2443,12 @@ class RemoteTransformersWrapper:
     """

     def __init__(
-        self, model, max_concurrency: int = 16, validate_model: bool = True, **kwargs
+        self,
+        model,
+        max_concurrency: int = 16,
+        validate_model: bool = True,
+        actor_name: str = None,
+        **kwargs,
     ):
         import ray

@@ -2458,10 +2463,23 @@ def __init__(

         if not ray.is_initialized():
             ray.init()
-        # Create the remote actor
+
+        if actor_name is not None:
+            # Check if an actor with this name already exists
+            try:
+                existing_actor = ray.get_actor(actor_name)
+                # If we can get the actor, assume it's alive and use it
+                self._remote_wrapper = existing_actor
+                torchrl_logger.info(f"Using existing actor {actor_name}")
+                return
+            except ValueError:
+                # Actor doesn't exist, create a new one
+                torchrl_logger.info(f"Creating new actor {actor_name}")
+
+        # Create the remote actor with the unique name
         self._remote_wrapper = (
             ray.remote(TransformersWrapper)
-            .options(max_concurrency=max_concurrency)
+            .options(max_concurrency=max_concurrency, name=actor_name)
             .remote(model, **kwargs)
         )

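Both RemoteTransformersWrapper above and RemotevLLMWrapper below now rely on the same get-or-create-by-name idiom: look the actor up with `ray.get_actor`, and only create a new named actor when the lookup raises `ValueError`. The following self-contained sketch shows that idiom with a toy `Counter` actor class; `Counter` and `get_or_create_actor` are hypothetical names invented for illustration, not part of TorchRL.

# Hedged sketch of the get-or-create-by-name pattern used in the diff above.
import ray


class Counter:
    """Toy stand-in for TransformersWrapper / vLLMWrapper."""

    def __init__(self):
        self._n = 0

    def bump(self):
        self._n += 1
        return self._n


def get_or_create_actor(actor_name=None, max_concurrency=16):
    if not ray.is_initialized():
        ray.init()
    if actor_name is not None:
        try:
            # Reuse a live actor registered under this name, if any.
            return ray.get_actor(actor_name)
        except ValueError:
            pass  # no such actor yet; fall through and create it
    # name=None simply creates an anonymous (unshared) actor.
    return (
        ray.remote(Counter)
        .options(max_concurrency=max_concurrency, name=actor_name)
        .remote()
    )


if __name__ == "__main__":
    a = get_or_create_actor("shared_counter")
    b = get_or_create_actor("shared_counter")
    # Both handles point at the same actor, so state is shared.
    print(ray.get(a.bump.remote()), ray.get(b.bump.remote()))  # 1 2

One design implication: `ray.get_actor` only returns a handle for an actor that is still alive, which is why the wrappers treat a successful lookup as proof the actor is usable (see the "assume it's alive" comment in the diff).
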
torchrl/modules/llm/policies/vllm_wrapper.py

Lines changed: 21 additions & 3 deletions
@@ -23,6 +23,7 @@
 from tensordict.utils import _zip_strict, NestedKey
 from torch import distributions as D
 from torch.nn.utils.rnn import pad_sequence
+from torchrl import logger as torchrl_logger

 from torchrl.envs.utils import _classproperty
 from torchrl.modules.llm.policies.common import (
@@ -2101,7 +2102,12 @@ class RemotevLLMWrapper:
     """

     def __init__(
-        self, model, max_concurrency: int = 16, validate_model: bool = True, **kwargs
+        self,
+        model,
+        max_concurrency: int = 16,
+        validate_model: bool = True,
+        actor_name: str = None,
+        **kwargs,
     ):
         import ray

@@ -2141,10 +2147,22 @@ def __init__(
         if not ray.is_initialized():
             ray.init()

-        # Create the remote actor
+        if actor_name is not None:
+            # Check if an actor with this name already exists
+            try:
+                existing_actor = ray.get_actor(actor_name)
+                torchrl_logger.info(f"Using existing actor {actor_name}")
+                # If we can get the actor, assume it's alive and use it
+                self._remote_wrapper = existing_actor
+                return
+            except ValueError:
+                # Actor doesn't exist, create a new one
+                torchrl_logger.info(f"Creating new actor {actor_name}")
+
+        # Create the remote actor with the unique name
         self._remote_wrapper = (
             ray.remote(vLLMWrapper)
-            .options(max_concurrency=max_concurrency)
+            .options(max_concurrency=max_concurrency, name=actor_name)
             .remote(model, **kwargs)
         )

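Putting the pieces together, the user-visible effect of the actor_name argument (exercised by TestActorSharing above) is that two remote wrappers constructed with the same name resolve to a single shared Ray actor, so the underlying model is only instantiated once. A condensed usage sketch follows; it mirrors the test, and the import location of `Text` is an assumption since it is not shown in this diff.

# Condensed from TestActorSharing above; "shared_qwen" is an illustrative name.
from tensordict import TensorDict
from torchrl.modules.llm import Text  # assumed import location
from torchrl.modules.llm.policies import RemotevLLMWrapper

w1 = RemotevLLMWrapper(
    model="Qwen/Qwen2.5-0.5B",
    generate=True,
    input_mode="text",
    generate_kwargs={"max_new_tokens": 5},
    actor_name="shared_qwen",
)
w2 = RemotevLLMWrapper(
    model="Qwen/Qwen2.5-0.5B",
    generate=True,
    input_mode="text",
    generate_kwargs={"max_new_tokens": 5},
    actor_name="shared_qwen",
)

# Both handles resolve to the same named actor.
assert w1._remote_wrapper == w2._remote_wrapper

td = TensorDict(text=Text(prompt="Hello, how are you?"), batch_size=())
print(w1(td)["text"].response)
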