Skip to content

Commit 4ba5066

Browse files
author
Vincent Moens
committed
[Feature] Collectors for async envs
ghstack-source-id: 764c21d Pull Request resolved: #2893
1 parent efe9389 commit 4ba5066

File tree

6 files changed

+282
-41
lines changed

6 files changed

+282
-41
lines changed

test/test_collector.py

Lines changed: 154 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
from torchrl.data.llm.dataset import _has_transformers
6868
from torchrl.data.utils import CloudpickleWrapper
6969
from torchrl.envs import (
70+
AsyncEnvPool,
7071
EnvBase,
7172
EnvCreator,
7273
InitTracker,
@@ -3737,10 +3738,12 @@ async def test_llm_collector_start(self, vllm_instance):
37373738
def test_llm_collector_completed(
37383739
self, vllm_instance_opt, rb, yield_only_last_steps
37393740
):
3741+
torch.manual_seed(0)
37403742
policy = vLLMWrapper(vllm_instance_opt)
37413743
tokenizer = vllm_instance_opt.get_tokenizer()
37423744
bsz = 4
37433745
total_steps = 20
3746+
max_steps = 20
37443747
dataloader = DummyStrDataLoader(bsz)
37453748

37463749
env = LLMEnv.from_dataloader(
@@ -3751,7 +3754,7 @@ def test_llm_collector_completed(
37513754
eos_token_id=tokenizer.eos_token_id,
37523755
)
37533756
# To make sure the env breaks at some point
3754-
env = env.append_transform(StepCounter(max_steps=100))
3757+
env = env.append_transform(StepCounter(max_steps=max_steps))
37553758

37563759
if rb:
37573760
rb = ReplayBuffer(storage=LazyStackStorage(max_size=total_steps * 2))
@@ -3774,11 +3777,27 @@ def test_llm_collector_completed(
37743777
for data in collector:
37753778
if rb is None:
37763779
assert data.ndim == 1
3777-
assert (data["next", "step_count"] < 99).all()
3780+
# assert (data["next", "step_count"] < max_steps-1).all()
37783781
cur_total_steps += data.numel()
37793782
for i in range(data.numel()):
3780-
# Check that there are more chars in the next step
3781-
assert len(data["text"][i]) < len(data["next", "text"][i])
3783+
if data[i]["next", "step_count"] == max_steps:
3784+
continue
3785+
if data[i]["text_response"]:
3786+
# Check that there are more chars in the next step
3787+
assert len(data["text"][i]) < len(data["next", "text"][i]), (
3788+
i,
3789+
data[i]["next", "step_count"],
3790+
data[i]["next", "done"],
3791+
data[i]["text_response"],
3792+
)
3793+
else:
3794+
assert len(data["text"][i]) == len(data["next", "text"][i]), (
3795+
i,
3796+
data[i]["next", "step_count"],
3797+
data[i]["next", "done"],
3798+
data[i]["text_response"],
3799+
)
3800+
37823801
if yield_only_last_steps:
37833802
assert data.shape == (1,)
37843803
else:
@@ -3787,8 +3806,137 @@ def test_llm_collector_completed(
37873806
assert data is None
37883807
sample = rb.sample(5)
37893808
for i in range(sample.numel()):
3790-
# Check that there are more chars in the next step
3791-
assert len(sample["text"][i]) < len(sample["next", "text"][i])
3809+
if sample[i]["next", "step_count"] == max_steps:
3810+
continue
3811+
if sample[i]["text_response"]:
3812+
# Check that there are more chars in the next step
3813+
assert len(sample["text"][i]) < len(
3814+
sample["next", "text"][i]
3815+
), (
3816+
i,
3817+
sample[i]["next", "step_count"],
3818+
sample[i]["next", "done"],
3819+
sample[i]["text_response"],
3820+
)
3821+
else:
3822+
assert len(sample["text"][i]) == len(
3823+
sample["next", "text"][i]
3824+
), (
3825+
i,
3826+
sample[i]["next", "step_count"],
3827+
sample[i]["next", "done"],
3828+
sample[i]["text_response"],
3829+
)
3830+
3831+
assert sample.ndim == 1
3832+
assert sample.shape == (5,)
3833+
assert (sample["next", "step_count"] < 99).all()
3834+
cur_total_steps += 1
3835+
assert collector._frames >= cur_total_steps
3836+
if rb is None and not yield_only_last_steps:
3837+
assert has_found_one_with_more_steps
3838+
assert collector._frames >= total_steps
3839+
3840+
@pytest.mark.slow
3841+
@pytest.mark.parametrize("rb", [False, True])
3842+
@pytest.mark.parametrize("yield_only_last_steps", [False, True])
3843+
def test_llm_collector_completed_async(
3844+
self, vllm_instance_opt, rb, yield_only_last_steps
3845+
):
3846+
torch.manual_seed(0)
3847+
policy = vLLMWrapper(vllm_instance_opt)
3848+
tokenizer = vllm_instance_opt.get_tokenizer()
3849+
bsz = 4
3850+
total_steps = 20
3851+
max_steps = 20
3852+
dataloader = DummyStrDataLoader(bsz)
3853+
3854+
def env_maker():
3855+
env = LLMEnv.from_dataloader(
3856+
dataloader=dataloader,
3857+
str2str=True,
3858+
batch_size=(),
3859+
group_repeats=True,
3860+
eos_token_id=tokenizer.eos_token_id,
3861+
)
3862+
# To make sure the env breaks at some point
3863+
env = env.append_transform(StepCounter(max_steps=max_steps))
3864+
return env
3865+
3866+
env = AsyncEnvPool([env_maker] * bsz, backend="threading", stack="lazy")
3867+
3868+
if rb:
3869+
rb = ReplayBuffer(storage=LazyStackStorage(max_size=total_steps * 2))
3870+
else:
3871+
rb = None
3872+
collector = LLMCollector(
3873+
env=env,
3874+
policy_factory=lambda: policy,
3875+
steps_per_batch=env.batch_size[0],
3876+
replay_buffer=rb,
3877+
total_steps=total_steps,
3878+
yield_completed_trajectories=True,
3879+
yield_only_last_steps=yield_only_last_steps,
3880+
)
3881+
assert collector.yield_completed_trajectories
3882+
assert collector.yield_only_last_steps is yield_only_last_steps
3883+
3884+
cur_total_steps = 0
3885+
has_found_one_with_more_steps = False
3886+
for data in collector:
3887+
if rb is None:
3888+
assert data.ndim == 1
3889+
# assert (data["next", "step_count"] < max_steps-1).all()
3890+
cur_total_steps += data.numel()
3891+
for i in range(data.numel()):
3892+
if data[i]["next", "step_count"] == max_steps:
3893+
continue
3894+
if data[i]["text_response"]:
3895+
# Check that there are more chars in the next step
3896+
assert len(data["text"][i]) < len(data["next", "text"][i]), (
3897+
i,
3898+
data[i]["next", "step_count"],
3899+
data[i]["next", "done"],
3900+
data[i]["text_response"],
3901+
)
3902+
else:
3903+
assert len(data["text"][i]) == len(data["next", "text"][i]), (
3904+
i,
3905+
data[i]["next", "step_count"],
3906+
data[i]["next", "done"],
3907+
data[i]["text_response"],
3908+
)
3909+
3910+
if yield_only_last_steps:
3911+
assert data.shape == (1,)
3912+
else:
3913+
has_found_one_with_more_steps |= data.numel() > 1
3914+
else:
3915+
assert data is None
3916+
sample = rb.sample(5)
3917+
for i in range(sample.numel()):
3918+
if sample[i]["next", "step_count"] == max_steps:
3919+
continue
3920+
if sample[i]["text_response"]:
3921+
# Check that there are more chars in the next step
3922+
assert len(sample["text"][i]) < len(
3923+
sample["next", "text"][i]
3924+
), (
3925+
i,
3926+
sample[i]["next", "step_count"],
3927+
sample[i]["next", "done"],
3928+
sample[i]["text_response"],
3929+
)
3930+
else:
3931+
assert len(sample["text"][i]) == len(
3932+
sample["next", "text"][i]
3933+
), (
3934+
i,
3935+
sample[i]["next", "step_count"],
3936+
sample[i]["next", "done"],
3937+
sample[i]["text_response"],
3938+
)
3939+
37923940
assert sample.ndim == 1
37933941
assert sample.shape == (5,)
37943942
assert (sample["next", "step_count"] < 99).all()

torchrl/collectors/llm.py

Lines changed: 91 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
WeightUpdateSenderBase,
1818
)
1919
from torchrl.data.replay_buffers.replay_buffers import ReplayBuffer
20+
from torchrl.envs import AsyncEnvPool
2021
from torchrl.envs.common import EnvBase
2122

2223

@@ -57,7 +58,8 @@ class LLMCollector(SyncDataCollector):
5758
a :class:`~torchrl.envs.Transform` or a :class:`~torchrl.data.postprocs.MultiStep`
5859
instance.
5960
Defaults to ``None``.
60-
async_envs (bool, optional): if ``True``, the environment will be run synchronously.
61+
async_envs (bool, optional): if ``True``, the environment will be run asynchronously. Defaults to `True` if the
62+
environment is a :class:`~torchrl.envs.AsyncEnvPool` instance.
6163
replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordicts
6264
but populate the buffer instead. Defaults to ``None``.
6365
reset_at_each_iter (bool, optional): if ``True``, the environment will be reset at each iteration.
@@ -149,7 +151,7 @@ def __init__(
149151
yield_completed_trajectories: bool | None = None,
150152
postproc: Callable[[TensorDictBase], TensorDictBase] | None = None,
151153
total_steps: int = -1,
152-
async_envs: bool = False,
154+
async_envs: bool | None = None,
153155
replay_buffer: ReplayBuffer | None = None,
154156
reset_at_each_iter: bool = False,
155157
flatten_data: bool | None = None,
@@ -160,8 +162,6 @@ def __init__(
160162
| Callable[[], WeightUpdateSenderBase]
161163
| None = None,
162164
):
163-
if async_envs:
164-
raise NotImplementedError
165165
super().__init__(
166166
create_env_fn=env,
167167
policy=policy,
@@ -209,6 +209,13 @@ def __init__(
209209
)
210210
self._yield_queues = [deque() for _ in range(self.env.batch_size[0])]
211211
self._trajectory_queue = deque()
212+
self.async_envs = bool(async_envs) | isinstance(self.env, AsyncEnvPool)
213+
if self.async_envs and not isinstance(self.env, AsyncEnvPool):
214+
# This basically means that `async_envs` is automatically set and passing it is useless as of today,
215+
# except for the following error.
216+
raise RuntimeError(
217+
"async_envs requires the environment to be an AsyncEnvPool instance."
218+
)
212219

213220
@property
214221
def steps_per_batch(self) -> int:
@@ -218,7 +225,10 @@ def steps_per_batch(self) -> int:
218225
@property
219226
def rollout(self) -> Callable[[], TensorDictBase]:
220227
if self.yield_completed_trajectories:
221-
return self._rollout_yield_trajs
228+
if self.async_envs:
229+
return self._rollout_yield_trajs_async
230+
else:
231+
return self._rollout_yield_trajs
222232
else:
223233
return self._rollout_all
224234

@@ -250,27 +260,33 @@ def _rollout_all(self) -> TensorDictBase: # A simplified version of rollout
250260

251261
def _rollout_yield_trajs(self) -> TensorDictBase: # A simplified version of rollout
252262
if self._shuttle is None:
253-
data = self.env.reset()
263+
raise RuntimeError("Data shuttle not found")
264+
# next_output = self.env.reset()
254265
else:
255-
data = self._shuttle
266+
next_output = self._shuttle
256267

257268
collected_steps = 0
258269
dones = torch.zeros(self.env.batch_size, dtype=torch.bool)
259270
while True:
260271
if self._trajectory_queue:
261272
break
262-
policy_input = data
263-
env_input = self.policy(policy_input)
264-
env_output, env_next_output = self.env.step_and_maybe_reset(env_input)
273+
env_input = self.policy(next_output)
274+
cur_output, next_output = self.env.step_and_maybe_reset(env_input)
275+
# for i in range(cur_output.numel()):
276+
# print(len(cur_output[i]["text"]) < len(cur_output[i]["next", "text"]))
265277

266278
# carry over collector data without messing up devices
267-
collector_data = env_output.get("collector").copy()
268-
env_next_output.set("collector", collector_data)
269-
self._shuttle = env_next_output
270-
self._update_traj_ids(env_output)
271-
data = env_output
272-
collected_steps += data.numel()
273-
for i, (_data, queue) in enumerate(zip(data.unbind(0), self._yield_queues)):
279+
self._update_traj_ids(cur_output)
280+
281+
collector_data = cur_output.get("collector").copy()
282+
next_output.set("collector", collector_data)
283+
284+
# if the loop is interrupted
285+
self._shuttle = next_output
286+
collected_steps += next_output.numel()
287+
for i, (_data, queue) in enumerate(
288+
zip(cur_output.unbind(0), self._yield_queues)
289+
):
274290
queue.append(_data)
275291
dones[i] = _data["next", "done"].any()
276292
if dones.any():
@@ -290,3 +306,61 @@ def _rollout_yield_trajs(self) -> TensorDictBase: # A simplified version of rol
290306

291307
result = self._trajectory_queue.popleft()
292308
return result
309+
310+
started = False
311+
312+
def _rollout_yield_trajs_async(
313+
self,
314+
) -> TensorDictBase: # A simplified version of rollout
315+
if not self.started:
316+
next_output = self._shuttle
317+
env_input = self.policy(next_output)
318+
self.env.async_step_and_maybe_reset_send(env_input)
319+
self.started = True
320+
321+
collected_steps = 0
322+
dones = torch.zeros(self.env.batch_size, dtype=torch.bool)
323+
while True:
324+
if self._trajectory_queue:
325+
break
326+
327+
cur_output, next_output = self.env.async_step_and_maybe_reset_recv()
328+
329+
# Get the env ids
330+
env_ids = cur_output.get(self.env._env_idx_key).tolist()
331+
332+
# carry over collector data without messing up devices
333+
self._update_traj_ids(cur_output)
334+
335+
collector_data = cur_output.get("collector").copy()
336+
next_output.set("collector", collector_data)
337+
338+
collected_steps += next_output.numel()
339+
dones.fill_(False)
340+
for i, _data in zip(env_ids, cur_output.unbind(0)):
341+
queue = self._yield_queues[i]
342+
queue.append(_data)
343+
dones[i] = _data["next", "done"].any()
344+
if dones.any():
345+
for idx in dones.nonzero()[0].tolist():
346+
if not self.yield_only_last_steps:
347+
self._trajectory_queue.append(
348+
lazy_stack(self._yield_queues[idx], -1)
349+
)
350+
else:
351+
# FIXME: We need to increment the step count here because iterator() won't
352+
# see the extra steps
353+
# We use lazy-stack because unsqueeze doesn't nest the strings in lists
354+
self._trajectory_queue.append(
355+
lazy_stack([self._yield_queues[idx][-1]])
356+
)
357+
self._yield_queues[idx].clear()
358+
359+
# Launch the next batch:
360+
# FIXME: Add a condition RE number of frames here
361+
if True:
362+
env_input = self.policy(next_output)
363+
self.env.async_step_and_maybe_reset_send(env_input)
364+
365+
result = self._trajectory_queue.popleft()
366+
return result

torchrl/envs/async_envs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def __init__(
219219

220220
output_spec, input_spec = self._setup()
221221
input_spec["full_state_spec"].set(
222-
self._env_idx_key, NonTensor(example_data=0, shape=self.batch_size)
222+
self._env_idx_key, NonTensor(example_data=0, shape=input_spec.shape)
223223
)
224224
self.__dict__["_output_spec"] = output_spec
225225
self.__dict__["_input_spec"] = input_spec

0 commit comments

Comments
 (0)