@@ -2705,7 +2705,7 @@ def test_parameter_validation(
                 tokenizer=tokenizer,
                 input_mode="text",
                 generate=True,
-                generate_kwargs={"repetition_penalty": 0.5},
+                generate_kwargs={"repetition_penalty": 0.0},
             )
 
         # Test conflicting do_sample and temperature
@@ -2725,36 +2725,129 @@ def test_parameter_validation(
         [vLLMWrapper, TransformersWrapperMaxTokens],
         ids=["vllm", "transformers"],
     )
-    def test_parameter_conflict_errors(
+    def test_batching_null_dimension(
         self, wrapper_class, vllm_instance, transformers_instance
     ):
-        """Test that parameter conflicts are properly detected."""
-        model = vllm_instance if wrapper_class == vLLMWrapper else transformers_instance
-        tokenizer = model.get_tokenizer() if hasattr(model, "get_tokenizer") else None
+        """Test that null-dimension inputs (batch_dims=0) work correctly.
 
-        # Test conflicting max_tokens and max_new_tokens
-        with pytest.raises(
-            ValueError, match="Cannot specify both 'max_tokens'.*'max_new_tokens'"
-        ):
-            wrapper_class(
-                model,
-                tokenizer=tokenizer,
-                input_mode="text",
-                generate=True,
-                generate_kwargs={"max_tokens": 10, "max_new_tokens": 20},
-            )
+        This test specifically verifies the fix for handling TensorDicts with
+        batch_dims=0 in the batching decorator, ensuring proper squeeze
+        operation and result handling.
+        """
+        # Handle the case where vLLM is not available
+        if wrapper_class == vLLMWrapper:
+            try:
+                model, tokenizer = vllm_instance
+            except Exception as e:
+                if "vLLM compatibility issue" in str(e):
+                    pytest.skip("vLLM not available due to compatibility issues")
+                raise
+        else:
+            model, tokenizer = transformers_instance
 
-        # Test conflicting n and num_return_sequences
-        with pytest.raises(
-            ValueError, match="Cannot specify both 'n'.*'num_return_sequences'"
-        ):
-            wrapper_class(
-                model,
-                tokenizer=tokenizer,
-                input_mode="text",
-                generate=True,
-                generate_kwargs={"n": 1, "num_return_sequences": 2},
-            )
+        # Test without batching first to verify basic functionality
+        wrapper_no_batch = wrapper_class(
+            model,
+            tokenizer=tokenizer,
+            input_mode="text",
+            generate=True,
+            return_log_probs=True,
+            # No batching parameters to avoid batching issues
+        )
+
+        # Test 1: Single null-dimension input should work.
+        # This is the key test case - a TensorDict without batch dimensions.
+        null_dim_input = TensorDict(
+            text=Text(prompt="Single question without batch dimension?"),
+            batch_size=(),  # Empty tuple means no batch dimension
+        )
+
+        result_null = wrapper_no_batch(null_dim_input)
+
+        # Verify the result structure
+        assert "text" in result_null
+        assert "tokens" in result_null
+        assert "masks" in result_null
+        assert "log_probs" in result_null
+
+        # Verify the result has the expected shape (should maintain the null dimension)
+        assert result_null.batch_size == ()
+        assert isinstance(
+            result_null["text"].prompt, str
+        )  # Should be a single string, not a list
+
+        # Test 2: Batch input should work normally
+        batch_input = TensorDict(
+            text=Text(prompt=["Question 1?", "Question 2?"]),
+            batch_size=(2,),
+        )
+
+        result_batch = wrapper_no_batch(batch_input)
+        assert result_batch.batch_size == (2,)
+        assert isinstance(result_batch["text"].prompt, list)
+        assert len(result_batch["text"].prompt) == 2
+
+        # Test 3: Batching enabled, with min_batch_size=1 to avoid complex batching
+        wrapper_with_batch = wrapper_class(
+            model,
+            tokenizer=tokenizer,
+            input_mode="text",
+            generate=True,
+            return_log_probs=True,
+            min_batch_size=1,  # Set to 1 to avoid complex batching scenarios
+        )
+
+        # Test null-dimension input with batching enabled
+        result_null_batch = wrapper_with_batch(null_dim_input)
+
+        # Verify the result structure
+        assert "text" in result_null_batch
+        assert "tokens" in result_null_batch
+        assert "masks" in result_null_batch
+        assert "log_probs" in result_null_batch
+
+        # Verify the result has the expected shape (should maintain the null dimension)
+        assert result_null_batch.batch_size == ()
+        assert isinstance(
+            result_null_batch["text"].prompt, str
+        )  # Should be a single string, not a list
+
+        # Test 4: Verify that the _batching decorator correctly handles the
+        # squeeze logic. This tests the specific fix in the _batching decorator.
+        from torchrl.modules.llm.policies.common import _batching
+
+        # Create a simple mock function to test the decorator
+        def mock_forward(self, td_input, **kwargs):
+            # Return the input as-is for testing
+            return td_input
+
+        # Apply the batching decorator
+        batched_mock = _batching(mock_forward)
+
+        # Create a mock self object with batching attributes
+        class MockSelf:
+            def __init__(self):
+                self._min_batch_size = 1
+                self._max_batch_size = None
+                self._batch_queue = []
+                self._futures = []
+                self._batching_lock = type(
+                    "MockLock",
+                    (),
+                    {
+                        "__enter__": lambda self: None,
+                        "__exit__": lambda self, *args: None,
+                    },
+                )()
+                self._batching_timeout = 10.0
+
+        mock_self = MockSelf()
+
+        # Test the decorator with null-dimension input
+        result = batched_mock(mock_self, null_dim_input)
+
+        # The result should be the same as the input since our mock just returns it
+        assert result.batch_size == ()
+        assert result["text"].prompt == "Single question without batch dimension?"
 
 
 if __name__ == "__main__":
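
For context, here is a minimal sketch of the squeeze/unsqueeze contract that test_batching_null_dimension exercises: a batch_dims=0 TensorDict is promoted to a one-element batch before the wrapped forward runs, and the result is squeezed back on the way out. The names _batching_sketch and _Echo are hypothetical, and the body is an assumption for illustration; TorchRL's real _batching decorator also queues and merges concurrent requests, which is omitted here.

# Illustrative sketch only -- not TorchRL's actual _batching implementation.
import functools

from tensordict import TensorDict


def _batching_sketch(forward):
    @functools.wraps(forward)
    def wrapper(self, td, **kwargs):
        squeezed = td.batch_dims == 0
        if squeezed:
            # Promote the null-dimension input to a one-element batch.
            td = td.unsqueeze(0)
        result = forward(self, td, **kwargs)
        if squeezed:
            # Restore the original null dimension on the way out.
            result = result.squeeze(0)
        return result

    return wrapper


class _Echo:
    @_batching_sketch
    def forward(self, td, **kwargs):
        return td


# A batch_dims=0 TensorDict round-trips with its empty batch_size preserved.
assert _Echo().forward(TensorDict({}, batch_size=())).batch_size == ()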