67
67
from torchrl .data .llm .dataset import _has_transformers
68
68
from torchrl .data .utils import CloudpickleWrapper
69
69
from torchrl .envs import (
70
+ AsyncEnvPool ,
70
71
EnvBase ,
71
72
EnvCreator ,
72
73
InitTracker ,
@@ -3737,10 +3738,12 @@ async def test_llm_collector_start(self, vllm_instance):
3737
3738
def test_llm_collector_completed (
3738
3739
self , vllm_instance_opt , rb , yield_only_last_steps
3739
3740
):
3741
+ torch .manual_seed (0 )
3740
3742
policy = vLLMWrapper (vllm_instance_opt )
3741
3743
tokenizer = vllm_instance_opt .get_tokenizer ()
3742
3744
bsz = 4
3743
3745
total_steps = 20
3746
+ max_steps = 20
3744
3747
dataloader = DummyStrDataLoader (bsz )
3745
3748
3746
3749
env = LLMEnv .from_dataloader (
@@ -3751,7 +3754,7 @@ def test_llm_collector_completed(
3751
3754
eos_token_id = tokenizer .eos_token_id ,
3752
3755
)
3753
3756
# To make sure the env breaks at some point
3754
- env = env .append_transform (StepCounter (max_steps = 100 ))
3757
+ env = env .append_transform (StepCounter (max_steps = max_steps ))
3755
3758
3756
3759
if rb :
3757
3760
rb = ReplayBuffer (storage = LazyStackStorage (max_size = total_steps * 2 ))
@@ -3774,11 +3777,27 @@ def test_llm_collector_completed(
3774
3777
for data in collector :
3775
3778
if rb is None :
3776
3779
assert data .ndim == 1
3777
- assert (data ["next" , "step_count" ] < 99 ).all ()
3780
+ # assert (data["next", "step_count"] < max_steps-1 ).all()
3778
3781
cur_total_steps += data .numel ()
3779
3782
for i in range (data .numel ()):
3780
- # Check that there are more chars in the next step
3781
- assert len (data ["text" ][i ]) < len (data ["next" , "text" ][i ])
3783
+ if data [i ]["next" , "step_count" ] == max_steps :
3784
+ continue
3785
+ if data [i ]["text_response" ]:
3786
+ # Check that there are more chars in the next step
3787
+ assert len (data ["text" ][i ]) < len (data ["next" , "text" ][i ]), (
3788
+ i ,
3789
+ data [i ]["next" , "step_count" ],
3790
+ data [i ]["next" , "done" ],
3791
+ data [i ]["text_response" ],
3792
+ )
3793
+ else :
3794
+ assert len (data ["text" ][i ]) == len (data ["next" , "text" ][i ]), (
3795
+ i ,
3796
+ data [i ]["next" , "step_count" ],
3797
+ data [i ]["next" , "done" ],
3798
+ data [i ]["text_response" ],
3799
+ )
3800
+
3782
3801
if yield_only_last_steps :
3783
3802
assert data .shape == (1 ,)
3784
3803
else :
@@ -3787,8 +3806,137 @@ def test_llm_collector_completed(
3787
3806
assert data is None
3788
3807
sample = rb .sample (5 )
3789
3808
for i in range (sample .numel ()):
3790
- # Check that there are more chars in the next step
3791
- assert len (sample ["text" ][i ]) < len (sample ["next" , "text" ][i ])
3809
+ if sample [i ]["next" , "step_count" ] == max_steps :
3810
+ continue
3811
+ if sample [i ]["text_response" ]:
3812
+ # Check that there are more chars in the next step
3813
+ assert len (sample ["text" ][i ]) < len (
3814
+ sample ["next" , "text" ][i ]
3815
+ ), (
3816
+ i ,
3817
+ sample [i ]["next" , "step_count" ],
3818
+ sample [i ]["next" , "done" ],
3819
+ sample [i ]["text_response" ],
3820
+ )
3821
+ else :
3822
+ assert len (sample ["text" ][i ]) == len (
3823
+ sample ["next" , "text" ][i ]
3824
+ ), (
3825
+ i ,
3826
+ sample [i ]["next" , "step_count" ],
3827
+ sample [i ]["next" , "done" ],
3828
+ sample [i ]["text_response" ],
3829
+ )
3830
+
3831
+ assert sample .ndim == 1
3832
+ assert sample .shape == (5 ,)
3833
+ assert (sample ["next" , "step_count" ] < 99 ).all ()
3834
+ cur_total_steps += 1
3835
+ assert collector ._frames >= cur_total_steps
3836
+ if rb is None and not yield_only_last_steps :
3837
+ assert has_found_one_with_more_steps
3838
+ assert collector ._frames >= total_steps
3839
+
3840
+ @pytest .mark .slow
3841
+ @pytest .mark .parametrize ("rb" , [False , True ])
3842
+ @pytest .mark .parametrize ("yield_only_last_steps" , [False , True ])
3843
+ def test_llm_collector_completed_async (
3844
+ self , vllm_instance_opt , rb , yield_only_last_steps
3845
+ ):
3846
+ torch .manual_seed (0 )
3847
+ policy = vLLMWrapper (vllm_instance_opt )
3848
+ tokenizer = vllm_instance_opt .get_tokenizer ()
3849
+ bsz = 4
3850
+ total_steps = 20
3851
+ max_steps = 20
3852
+ dataloader = DummyStrDataLoader (bsz )
3853
+
3854
+ def env_maker ():
3855
+ env = LLMEnv .from_dataloader (
3856
+ dataloader = dataloader ,
3857
+ str2str = True ,
3858
+ batch_size = (),
3859
+ group_repeats = True ,
3860
+ eos_token_id = tokenizer .eos_token_id ,
3861
+ )
3862
+ # To make sure the env breaks at some point
3863
+ env = env .append_transform (StepCounter (max_steps = max_steps ))
3864
+ return env
3865
+
3866
+ env = AsyncEnvPool ([env_maker ] * bsz , backend = "threading" , stack = "lazy" )
3867
+
3868
+ if rb :
3869
+ rb = ReplayBuffer (storage = LazyStackStorage (max_size = total_steps * 2 ))
3870
+ else :
3871
+ rb = None
3872
+ collector = LLMCollector (
3873
+ env = env ,
3874
+ policy_factory = lambda : policy ,
3875
+ steps_per_batch = env .batch_size [0 ],
3876
+ replay_buffer = rb ,
3877
+ total_steps = total_steps ,
3878
+ yield_completed_trajectories = True ,
3879
+ yield_only_last_steps = yield_only_last_steps ,
3880
+ )
3881
+ assert collector .yield_completed_trajectories
3882
+ assert collector .yield_only_last_steps is yield_only_last_steps
3883
+
3884
+ cur_total_steps = 0
3885
+ has_found_one_with_more_steps = False
3886
+ for data in collector :
3887
+ if rb is None :
3888
+ assert data .ndim == 1
3889
+ # assert (data["next", "step_count"] < max_steps-1).all()
3890
+ cur_total_steps += data .numel ()
3891
+ for i in range (data .numel ()):
3892
+ if data [i ]["next" , "step_count" ] == max_steps :
3893
+ continue
3894
+ if data [i ]["text_response" ]:
3895
+ # Check that there are more chars in the next step
3896
+ assert len (data ["text" ][i ]) < len (data ["next" , "text" ][i ]), (
3897
+ i ,
3898
+ data [i ]["next" , "step_count" ],
3899
+ data [i ]["next" , "done" ],
3900
+ data [i ]["text_response" ],
3901
+ )
3902
+ else :
3903
+ assert len (data ["text" ][i ]) == len (data ["next" , "text" ][i ]), (
3904
+ i ,
3905
+ data [i ]["next" , "step_count" ],
3906
+ data [i ]["next" , "done" ],
3907
+ data [i ]["text_response" ],
3908
+ )
3909
+
3910
+ if yield_only_last_steps :
3911
+ assert data .shape == (1 ,)
3912
+ else :
3913
+ has_found_one_with_more_steps |= data .numel () > 1
3914
+ else :
3915
+ assert data is None
3916
+ sample = rb .sample (5 )
3917
+ for i in range (sample .numel ()):
3918
+ if sample [i ]["next" , "step_count" ] == max_steps :
3919
+ continue
3920
+ if sample [i ]["text_response" ]:
3921
+ # Check that there are more chars in the next step
3922
+ assert len (sample ["text" ][i ]) < len (
3923
+ sample ["next" , "text" ][i ]
3924
+ ), (
3925
+ i ,
3926
+ sample [i ]["next" , "step_count" ],
3927
+ sample [i ]["next" , "done" ],
3928
+ sample [i ]["text_response" ],
3929
+ )
3930
+ else :
3931
+ assert len (sample ["text" ][i ]) == len (
3932
+ sample ["next" , "text" ][i ]
3933
+ ), (
3934
+ i ,
3935
+ sample [i ]["next" , "step_count" ],
3936
+ sample [i ]["next" , "done" ],
3937
+ sample [i ]["text_response" ],
3938
+ )
3939
+
3792
3940
assert sample .ndim == 1
3793
3941
assert sample .shape == (5 ,)
3794
3942
assert (sample ["next" , "step_count" ] < 99 ).all ()
0 commit comments