
Commit 79217c8

[BugFix] Fix parameter conflict resolution and validation tests in Wrappers (#3114)
1 parent 6dfbaf8 commit 79217c8

File tree

2 files changed: +140 / -43 lines changed


test/llm/test_wrapper.py

Lines changed: 120 additions & 27 deletions
@@ -2705,7 +2705,7 @@ def test_parameter_validation(
                 tokenizer=tokenizer,
                 input_mode="text",
                 generate=True,
-                generate_kwargs={"repetition_penalty": 0.5},
+                generate_kwargs={"repetition_penalty": 0.0},
             )

         # Test conflicting do_sample and temperature
@@ -2725,36 +2725,129 @@ def test_parameter_validation(
         [vLLMWrapper, TransformersWrapperMaxTokens],
         ids=["vllm", "transformers"],
     )
-    def test_parameter_conflict_errors(
+    def test_batching_null_dimension(
         self, wrapper_class, vllm_instance, transformers_instance
     ):
-        """Test that parameter conflicts are properly detected."""
-        model = vllm_instance if wrapper_class == vLLMWrapper else transformers_instance
-        tokenizer = model.get_tokenizer() if hasattr(model, "get_tokenizer") else None
+        """Test that null dimension inputs (batch_dims=0) work correctly.

-        # Test conflicting max_tokens and max_new_tokens
-        with pytest.raises(
-            ValueError, match="Cannot specify both 'max_tokens'.*'max_new_tokens'"
-        ):
-            wrapper_class(
-                model,
-                tokenizer=tokenizer,
-                input_mode="text",
-                generate=True,
-                generate_kwargs={"max_tokens": 10, "max_new_tokens": 20},
-            )
+        This test specifically verifies the fix for handling TensorDicts with batch_dims=0
+        in the batching decorator, ensuring proper squeeze operation and result handling.
+        """
+        # Handle the case where vLLM is not available
+        if wrapper_class == vLLMWrapper:
+            try:
+                model, tokenizer = vllm_instance
+            except Exception as e:
+                if "vLLM compatibility issue" in str(e):
+                    pytest.skip("vLLM not available due to compatibility issues")
+                raise
+        else:
+            model, tokenizer = transformers_instance

-        # Test conflicting n and num_return_sequences
-        with pytest.raises(
-            ValueError, match="Cannot specify both 'n'.*'num_return_sequences'"
-        ):
-            wrapper_class(
-                model,
-                tokenizer=tokenizer,
-                input_mode="text",
-                generate=True,
-                generate_kwargs={"n": 1, "num_return_sequences": 2},
-            )
+        # Test without batching first to verify basic functionality
+        wrapper_no_batch = wrapper_class(
+            model,
+            tokenizer=tokenizer,
+            input_mode="text",
+            generate=True,
+            return_log_probs=True,
+            # No batching parameters to avoid batching issues
+        )
+
+        # Test 1: Single null dimension input should work
+        # This is the key test case - a TensorDict without batch dimensions
+        null_dim_input = TensorDict(
+            text=Text(prompt="Single question without batch dimension?"),
+            batch_size=(),  # Empty tuple means no batch dimension
+        )
+
+        result_null = wrapper_no_batch(null_dim_input)
+
+        # Verify the result structure
+        assert "text" in result_null
+        assert "tokens" in result_null
+        assert "masks" in result_null
+        assert "log_probs" in result_null
+
+        # Verify the result has the expected shape (should maintain null dimension)
+        assert result_null.batch_size == ()
+        assert isinstance(
+            result_null["text"].prompt, str
+        )  # Should be a single string, not a list
+
+        # Test 2: Batch input should work normally
+        batch_input = TensorDict(
+            text=Text(prompt=["Question 1?", "Question 2?"]),
+            batch_size=(2,),
+        )
+
+        result_batch = wrapper_no_batch(batch_input)
+        assert result_batch.batch_size == (2,)
+        assert isinstance(result_batch["text"].prompt, list)
+        assert len(result_batch["text"].prompt) == 2
+
+        # Test 3: Test with batching enabled but with min_batch_size=1 to avoid complex batching
+        wrapper_with_batch = wrapper_class(
+            model,
+            tokenizer=tokenizer,
+            input_mode="text",
+            generate=True,
+            return_log_probs=True,
+            min_batch_size=1,  # Set to 1 to avoid complex batching scenarios
+        )
+
+        # Test null dimension with batching enabled
+        result_null_batch = wrapper_with_batch(null_dim_input)
+
+        # Verify the result structure
+        assert "text" in result_null_batch
+        assert "tokens" in result_null_batch
+        assert "masks" in result_null_batch
+        assert "log_probs" in result_null_batch
+
+        # Verify the result has the expected shape (should maintain null dimension)
+        assert result_null_batch.batch_size == ()
+        assert isinstance(
+            result_null_batch["text"].prompt, str
+        )  # Should be a single string, not a list
+
+        # Test 4: Verify that the _batching decorator correctly handles the squeeze logic
+        # This tests the specific fix in the _batching decorator
+        from torchrl.modules.llm.policies.common import _batching
+
+        # Create a simple mock function to test the decorator
+        def mock_forward(self, td_input, **kwargs):
+            # Return the input as-is for testing
+            return td_input
+
+        # Apply the batching decorator
+        batched_mock = _batching(mock_forward)
+
+        # Create a mock self object with batching attributes
+        class MockSelf:
+            def __init__(self):
+                self._min_batch_size = 1
+                self._max_batch_size = None
+                self._batch_queue = []
+                self._futures = []
+                self._batching_lock = type(
+                    "MockLock",
+                    (),
+                    {
+                        "__enter__": lambda self: None,
+                        "__exit__": lambda self, *args: None,
+                    },
+                )()
+                self._batching_timeout = 10.0
+
+        mock_self = MockSelf()
+
+        # Test the decorator with null dimension input
+        result = batched_mock(mock_self, null_dim_input)
+
+        # The result should be the same as the input since our mock just returns the input
+        assert result.batch_size == ()
+        assert result["text"].prompt == "Single question without batch dimension?"


 if __name__ == "__main__":
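
The new test above exercises the `_batching` fix in `torchrl/modules/llm/policies/common.py` (second file below). As a rough, self-contained sketch of the idea — not the actual decorator, which queues inputs and resolves futures with `lazy_stack` — the null-dimension handling amounts to: wrap a `batch_size=()` TensorDict as a one-element batch, and on the way out return the single result directly instead of stacking it. The helper name `run_batched` and the use of `torch.stack` are illustrative choices only.

```python
import torch
from tensordict import TensorDict


def run_batched(process_one, td_input):
    # Sketch only: mirrors the squeeze logic of the fix, not torchrl's decorator.
    squeeze = td_input.batch_dims == 0
    # A 0-dim ("null dimension") input becomes a one-element batch;
    # a batched input is unbound along dim 0.
    inputs = [td_input] if squeeze else list(td_input.unbind(0))
    results = [process_one(item) for item in inputs]
    if squeeze:
        if len(results) > 1:
            raise RuntimeError("More results than expected")
        return results[0]  # keep batch_size=() on the way out, no stacking
    return torch.stack(results)


single = TensorDict(x=torch.zeros(3), batch_size=())
batch = TensorDict(x=torch.zeros(2, 3), batch_size=(2,))
assert run_batched(lambda td: td, single).batch_size == ()
assert run_batched(lambda td: td, batch).batch_size == (2,)
```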

torchrl/modules/llm/policies/common.py

Lines changed: 20 additions & 16 deletions
@@ -384,12 +384,12 @@ class LLMWrapperBase(TensorDictModuleBase):
     **Parameter Conflict Resolution:**

     When both legacy (backend-specific) and standardized parameter names are provided,
-    a :exc:`ValueError` is raised to prevent confusion. For example:
+    the legacy names silently prevail. This ensures backward compatibility with existing code.

-    * If both ``max_tokens`` and ``max_new_tokens`` are passed, an error is raised
-    * If both ``n`` and ``num_return_sequences`` are passed, an error is raised
+    * If both ``max_tokens`` and ``max_new_tokens`` are passed, ``max_tokens`` wins
+    * If both ``n`` and ``num_return_sequences`` are passed, ``n`` wins

-    This ensures clear parameter usage and prevents unexpected behavior.
+    This behavior allows existing code to continue working without modification.

     **Parameter Validation:**

@@ -520,35 +520,34 @@ def _standardize_generate_kwargs(cls, generate_kwargs: dict | None) -> dict:
         * vLLM's ``max_tokens`` -> ``max_new_tokens``
         * vLLM's ``n`` -> ``num_return_sequences``

+        **Parameter Conflict Resolution:**
+
+        When both legacy (backend-specific) and standardized parameter names are provided,
+        the legacy names silently prevail. This ensures backward compatibility with existing code.
+
         Args:
             generate_kwargs: The generation parameters to standardize

         Returns:
             Standardized generation parameters
-
-        Raises:
-            ValueError: If conflicting parameter names are provided
         """
         if generate_kwargs is None:
             return {}

         standardized = dict(generate_kwargs)

         # Convert vLLM parameter names to common names
+        # Legacy names prevail in conflicts (backward compatibility)
         if "max_tokens" in standardized:
             if "max_new_tokens" in standardized:
-                raise ValueError(
-                    "Cannot specify both 'max_tokens' (legacy vLLM) and 'max_new_tokens' (standardized). "
-                    "Use 'max_new_tokens' for cross-backend compatibility."
-                )
+                # Legacy name wins - remove the standardized name
+                standardized.pop("max_new_tokens")
             standardized["max_new_tokens"] = standardized.pop("max_tokens")

         if "n" in standardized:
             if "num_return_sequences" in standardized:
-                raise ValueError(
-                    "Cannot specify both 'n' (legacy vLLM) and 'num_return_sequences' (standardized). "
-                    "Use 'num_return_sequences' for cross-backend compatibility."
-                )
+                # Legacy name wins - remove the standardized name
+                standardized.pop("num_return_sequences")
             standardized["num_return_sequences"] = standardized.pop("n")

         # Validate parameter combinations
@@ -1259,7 +1258,9 @@ def _batched_func(self, td_input: TensorDictBase, **kwargs):
         max_batch_size = getattr(self, "_max_batch_size", None)
         if min_batch_size is not None or max_batch_size is not None:
             # put elements in a queue until the batch size is reached
+            squeeze = False
             if td_input.batch_dims == 0:
+                squeeze = True
                 inputs = [td_input]
             else:
                 if td_input.batch_dims > 1:
@@ -1339,7 +1340,10 @@ def _batched_func(self, td_input: TensorDictBase, **kwargs):
                     if not future.done():
                         future.set_exception(e)
                 raise
-
+            if squeeze:
+                if len(futures) > 1:
+                    raise RuntimeError("More results than expected")
+                return futures[0].result()
             return lazy_stack([future.result() for future in futures])
         return func(self, td_input, **kwargs)
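
To make the new conflict-resolution rule concrete, the snippet below is a standalone replica of the logic in `_standardize_generate_kwargs` shown above, written here only for illustration (it is not imported from torchrl, and the function name is hypothetical). Legacy vLLM names are mapped to the standardized names, and when both forms are passed the legacy value wins instead of raising a `ValueError` as before.

```python
def standardize_generate_kwargs(generate_kwargs: dict | None) -> dict:
    """Illustrative replica of the conflict-resolution logic in the diff above."""
    if generate_kwargs is None:
        return {}
    standardized = dict(generate_kwargs)
    # Legacy vLLM names prevail when both legacy and standardized names are given.
    if "max_tokens" in standardized:
        standardized.pop("max_new_tokens", None)  # drop the standardized duplicate, if any
        standardized["max_new_tokens"] = standardized.pop("max_tokens")
    if "n" in standardized:
        standardized.pop("num_return_sequences", None)
        standardized["num_return_sequences"] = standardized.pop("n")
    return standardized


# Both names supplied: the legacy values (max_tokens=10, n=1) silently win.
assert standardize_generate_kwargs(
    {"max_tokens": 10, "max_new_tokens": 20, "n": 1, "num_return_sequences": 2}
) == {"max_new_tokens": 10, "num_return_sequences": 1}
```

Under the previous behavior the same call would have raised a `ValueError`; with this commit the legacy parameters simply take precedence, matching the updated docstrings.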
