Commit 5a5f63d

[BugFix] Variable length vllm wrapper answer stacking (#3049)

1 parent f287bb3 · commit 5a5f63d

File tree: 7 files changed (+201, -26 lines)


README.md

Lines changed: 5 additions & 5 deletions
@@ -29,11 +29,11 @@
 
 TorchRL now includes a comprehensive **LLM API** for post-training and fine-tuning of language models! This new framework provides everything you need for RLHF, supervised fine-tuning, and tool-augmented training:
 
-- 🤖 **Unified LLM Wrappers**: Seamless integration with Hugging Face models and vLLM inference engines
-- 💬 **Conversation Management**: Advanced `History` class for multi-turn dialogue with automatic chat template detection
-- 🛠️ **Tool Integration**: Built-in support for Python code execution, function calling, and custom tool transforms
-- 🎯 **Specialized Objectives**: GRPO (Group Relative Policy Optimization) and SFT loss functions optimized for language models
-- **High-Performance Collectors**: Async data collection with distributed training support
+- 🤖 **Unified LLM Wrappers**: Seamless integration with Hugging Face models and vLLM inference engines - more to come!
+- 💬 **Conversation Management**: Advanced [`History`](torchrl/data/llm/history.py) class for multi-turn dialogue with automatic chat template detection
+- 🛠️ **Tool Integration**: [Built-in support](torchrl/envs/llm/transforms/) for Python code execution, function calling, and custom tool transforms
+- 🎯 **Specialized Objectives**: [GRPO](torchrl/objectives/llm/grpo.py) (Group Relative Policy Optimization) and [SFT](torchrl/objectives/llm/sft.py) loss functions optimized for language models
+- **High-Performance Collectors**: [Async data collection](torchrl/collectors/llm/) with distributed training support
 - 🔄 **Flexible Environments**: Transform-based architecture for reward computation, data loading, and conversation augmentation
 
 The LLM API follows TorchRL's modular design principles, allowing you to mix and match components for your specific use case. Check out the [complete documentation](https://pytorch.org/rl/main/reference/llms.html) and [GRPO implementation example](https://github.com/pytorch/rl/tree/main/sota-implementations/grpo) to get started!
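
The bullets above are backed by the `History` class that this commit also touches. For orientation, here is a minimal sketch of that API as exercised by the new test further down: `History.from_chats` builds a batch of conversations and `apply_chat_template` renders one of them. The Qwen tokenizer and `chat_template_name="qwen"` are borrowed from that test and assume the checkpoint can be downloaded; the exact rendered string depends on the template.

from torchrl.data.llm import History
from transformers import AutoTokenizer

# Two single-turn conversations, batched along the first dimension (shape (2, 1))
chats = History.from_chats(
    [
        [{"role": "user", "content": "Hello, how are you?"}],
        [{"role": "user", "content": "Tell me a joke."}],
    ]
)

# Render the first conversation with the Qwen chat template, as the test below does
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
prompt = chats[0].apply_chat_template(
    tokenizer=tokenizer, add_generation_prompt=True, chat_template_name="qwen"
)
print(prompt)  # ChatML-formatted prompt string ready for generation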

test/llm/test_data.py

Lines changed: 114 additions & 1 deletion
@@ -966,6 +966,119 @@ def norm(x):
             history.role[:-1]
         ), f"All roles except the last should match original. Original: {history.role[:-1]}, Parsed: {parsed.role[:-1]}"
 
+    @pytest.mark.skipif(not _has_transformers, reason="requires transformers library")
+    def test_extract_responses_from_full_histories_batch_issue(self):
+        """Test the isolated function for handling different response shapes in batch processing."""
+        from torchrl.modules.llm.policies.common import (
+            _extract_responses_from_full_histories,
+        )
+        from transformers import AutoTokenizer
+
+        # Create a batch of 2 prompt histories
+        prompt_histories = History.from_chats(
+            [
+                [
+                    {"role": "user", "content": "Hello, how are you?"},
+                ],
+                [
+                    {"role": "user", "content": "Tell me a joke."},
+                ],
+            ]
+        )
+
+        # Simulate generated text with different response counts
+        text_full = [
+            # First element: 1 assistant response
+            """<|im_start|>user
+Hello, how are you?<|im_end|>
+<|im_start|>assistant
+I'm doing well, thank you for asking!<|im_end|>""",
+            # Second element: 3 messages (1 assistant + 1 user + 1 assistant)
+            """<|im_start|>user
+Tell me a joke.<|im_end|>
+<|im_start|>assistant
+Why did the chicken cross the road?<|im_end|>
+<|im_start|>user
+I don't know, why?<|im_end|>
+<|im_start|>assistant
+To get to the other side!<|im_end|>""",
+        ]
+
+        # Test the isolated function
+        h_responses = _extract_responses_from_full_histories(
+            text_full, prompt_histories, chat_template_name="qwen"
+        )
+
+        # Verify the responses have the expected shapes and content
+        assert len(h_responses) == 2, f"Expected 2 responses, got {len(h_responses)}"
+
+        # Check first response (should be padded to match second response length)
+        response_0 = h_responses[0]
+        assert response_0.shape == (3,), f"Expected shape (3,), got {response_0.shape}"
+        assert response_0.role == [
+            "assistant",
+            "<none>",
+            "<none>",
+        ], f"Expected roles ['assistant', '<none>', '<none>'], got {response_0.role}"
+        assert response_0.content == [
+            "I'm doing well, thank you for asking!",
+            "",
+            "",
+        ], f"Expected content ['I\\'m doing well, thank you for asking!', '', ''], got {response_0.content}"
+
+        # Check second response (should have 3 messages)
+        response_1 = h_responses[1]
+        assert response_1.shape == (3,), f"Expected shape (3,), got {response_1.shape}"
+        assert response_1.role == [
+            "assistant",
+            "user",
+            "assistant",
+        ], f"Expected roles ['assistant', 'user', 'assistant'], got {response_1.role}"
+        assert response_1.content == [
+            "Why did the chicken cross the road?",
+            "I don't know, why?",
+            "To get to the other side!",
+        ], f"Expected content ['Why did the chicken cross the road?', 'I don\\'t know, why?', 'To get to the other side!'], got {response_1.content}"
+
+        assert isinstance(h_responses, History)
+        assert h_responses.shape == (
+            2,
+            3,
+        ), f"Expected stacked shape (2, 3), got {h_responses.shape}"
+
+        # Extract individual responses for testing
+        response_0 = h_responses[0]
+        response_1 = h_responses[1]
+
+        # Test chat template application
+        tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
+
+        # Test first response (should only show the assistant message, ignore padding)
+        template_0 = response_0.apply_chat_template(
+            tokenizer=tokenizer, add_generation_prompt=False, chat_template_name="qwen"
+        )
+        expected_0 = """<|im_start|>system
+You are a helpful assistant.<|im_end|>
+<|im_start|>assistant
+I'm doing well, thank you for asking!<|im_end|>
+"""
+        assert template_0 == expected_0
+
+        # Test second response (should show all 3 messages)
+        template_1 = response_1.apply_chat_template(
+            tokenizer=tokenizer, add_generation_prompt=False, chat_template_name="qwen"
+        )
+        expected_1 = """<|im_start|>system
+You are a helpful assistant.<|im_end|>
+<|im_start|>assistant
+Why did the chicken cross the road?<|im_end|>
+<|im_start|>user
+I don't know, why?<|im_end|>
+<|im_start|>assistant
+To get to the other side!<|im_end|>
+"""
+        assert template_1 == expected_1
+
 
 class TestTopK:
     @pytest.mark.parametrize("per_token_reward", [True, False])
@@ -989,7 +1102,7 @@ def _per_token_reward(i):
                     ("next", "done"): torch.full((1, 1), True),
                     ("next", "reward"): _per_token_reward(i),
                     # total of 10 dialogs per prompt
-                    "text": f"Prompt {i // 5}",
+                    ("text", "prompt"): f"Prompt {i // 5}",
                 }
             )
             for i in range(50)

torchrl/data/llm/history.py

Lines changed: 2 additions & 0 deletions
@@ -1166,6 +1166,8 @@ def append(
         Returns:
             History: The appended History object.
         """
+        # TODO: we should remove the <none> role from the history before appending / extending
+        # It works when keeping them, but it may lead to a lot of useless padding in between valid messages
         if not self.batch_dims:
             raise RuntimeError(
                 "Cannot append an element to a batchless History. Call unsqueeze(dim=0) first on self."

torchrl/data/replay_buffers/replay_buffers.py

Lines changed: 3 additions & 1 deletion
@@ -705,7 +705,9 @@ def add(self, data: Any) -> int:
         make_none = False
         # Transforms usually expect a time batch dimension when called within a RB, so we unsqueeze the data temporarily
         is_tc = is_tensor_collection(data)
-        with data.unsqueeze(-1) if is_tc else contextlib.nullcontext(data) as data_unsq:
+        with data.unsqueeze(-1) if is_tc else contextlib.nullcontext(
+            data
+        ) as data_unsq:
             data_unsq_r = self._transform.inv(data_unsq)
             if is_tc and data_unsq_r is not None:
                 # this is a no-op whenever the result matches the input

torchrl/modules/llm/policies/common.py

Lines changed: 67 additions & 0 deletions
@@ -854,3 +854,70 @@ def log_prob(self, data: TensorDictBase, **get_kwargs) -> TensorDictBase:
             data = self(data)
             return data.get((self.log_prob_key, "response"), **get_kwargs)
         raise RuntimeError("log_prob not callable when generate=True.")
+
+
+def _extract_responses_from_full_histories(
+    text_full: list[str],
+    prompt_histories,
+    chat_template_name: str | None = None,
+    tokenizer=None,
+) -> History:
+    """Extract response histories from full text histories.
+
+    This function parses the full text back to history objects and extracts
+    the response portions (everything after the prompt).
+
+    Args:
+        text_full: List of full text strings to parse
+        prompt_histories: The original prompt histories
+        chat_template_name: Optional chat template name for parsing
+        tokenizer: Optional tokenizer for template detection
+
+    Returns:
+        Stacked History object with response portions
+
+    Raises:
+        RuntimeError: If full history is shorter than prompt history
+        RuntimeError: If parsing produces inconsistent batch shapes
+    """
+    import torch
+    from tensordict.utils import _zip_strict
+    from torchrl.data.llm import History
+
+    # Extract response portions by processing each element individually.
+    # This avoids the stacking issue when different batch elements produce
+    # different numbers of responses.
+    response_histories = []
+    full_histories = History.from_text(
+        text_full, chat_template_name=chat_template_name, tokenizer=tokenizer
+    )
+    for h_prompt, h_full in _zip_strict(
+        prompt_histories.unbind(0), full_histories.unbind(0)
+    ):
+        if h_full.shape[0] <= h_prompt.shape[0]:
+            raise RuntimeError(
+                f"Full history is shorter than prompt history: {h_full.shape} <= {h_prompt.shape}"
+            )
+        # Note: there can be more than one response, so the response has the same number of dims as prompt
+        response_histories.append(h_full[h_prompt.shape[0] :])
+
+    # Check if all responses have the same shape
+    shapes = [r.shape for r in response_histories]
+    if len(set(shapes)) > 1:
+        # Different shapes detected - pad to the same length
+        max_length = max(r.shape[0] for r in response_histories)
+        padded_responses = []
+        for response in response_histories:
+            if response.shape[0] < max_length:
+                # Pad with empty messages using "<none>" role
+                padding_needed = max_length - response.shape[0]
+                padding_history = History(
+                    role="<none>", content="", batch_size=(padding_needed,)
+                )
+                padded_response = response.extend(padding_history, inplace=False)
+                padded_responses.append(padded_response)
+            else:
+                padded_responses.append(response)
+        return torch.stack(padded_responses)
+
+    return torch.stack(response_histories)
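
To make the padding strategy above concrete, here is a small sketch of the pad-then-stack step that `_extract_responses_from_full_histories` performs when batch elements yield different numbers of response messages. It reuses only calls that appear elsewhere in this commit (`History.from_chats`, `extend`, `torch.stack`, and the `"<none>"` placeholder role); the two hand-built conversations are illustrative, not taken from the codebase.

import torch

from torchrl.data.llm import History

# Two response histories of unequal length (1 vs 3 messages), as happens when
# one batch element ends after a single assistant turn while another contains
# an assistant/user/assistant exchange.
short = History.from_chats(
    [[{"role": "assistant", "content": "I'm doing well, thank you for asking!"}]]
)[0]
multi = History.from_chats(
    [
        [
            {"role": "assistant", "content": "Why did the chicken cross the road?"},
            {"role": "user", "content": "I don't know, why?"},
            {"role": "assistant", "content": "To get to the other side!"},
        ]
    ]
)[0]

responses = [short, multi]
max_length = max(r.shape[0] for r in responses)

padded = []
for r in responses:
    missing = max_length - r.shape[0]
    if missing:
        # Same trick as in _extract_responses_from_full_histories: fill the gap
        # with placeholder "<none>" messages carrying empty content.
        pad = History(role="<none>", content="", batch_size=(missing,))
        r = r.extend(pad, inplace=False)
    padded.append(r)

stacked = torch.stack(padded)  # now a regular (2, 3) History batch
print(stacked.shape, stacked.role)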

torchrl/modules/llm/policies/transformers_wrapper.py

Lines changed: 5 additions & 9 deletions
@@ -24,6 +24,7 @@
 from torch.nn.utils.rnn import pad_sequence
 
 from torchrl.modules.llm.policies.common import (
+    _extract_responses_from_full_histories,
     ChatHistory,
     LLMWrapperBase,
     LogProbs,
@@ -680,16 +681,11 @@ def _from_transformers_generate_history(self, td, cfg, out) -> TensorDictBase:
         for h in history_chat.unbind(1):
             h.prompt = history
         with history_chat.view(-1) as history_chat_flat:
-            history_chat_flat.full = full_histories = History.from_text(text_full)
             prompt_histories = history_chat_flat.prompt
-            # iterate over batch
-            h_responses = []
-            for h_full, h_prompt in _zip_strict(
-                full_histories.unbind(0), prompt_histories.unbind(0)
-            ):
-                if h_full.shape[0] <= h_prompt.shape[0]:
-                    raise RuntimeError("Full history is shorter than prompt history")
-                h_responses.append(h_full[h_prompt.shape[0] :])
+            # Extract response histories from full text
+            h_responses = _extract_responses_from_full_histories(
+                text_full, prompt_histories, self.chat_template_name, self.tokenizer
+            )
             history_chat_flat.response = torch.stack(h_responses)
     result.set(self.history_key, history_chat)
     return result

torchrl/modules/llm/policies/vllm_wrapper.py

Lines changed: 5 additions & 10 deletions
@@ -24,6 +24,7 @@
 
 from torchrl.envs.utils import _classproperty
 from torchrl.modules.llm.policies.common import (
+    _extract_responses_from_full_histories,
     ChatHistory,
     LLMWrapperBase,
     LogProbs,
@@ -720,17 +721,11 @@ def _from_vllm_generate_history(
         for h in history_chat.unbind(1):
             h.prompt = history
         with history_chat.view(-1) as history_chat_flat:
-            history_chat_flat.full = full_histories = History.from_text(text_full)
             prompt_histories = history_chat_flat.prompt
-            # iterate over batch
-            h_responses = []
-            for h_full, h_prompt in _zip_strict(
-                full_histories.unbind(0), prompt_histories.unbind(0)
-            ):
-                if h_full.shape[0] <= h_prompt.shape[0]:
-                    raise RuntimeError("Full history is shorter than prompt history")
-                # Note: there can be more than one response, so the response has the same number of dims as prompt
-                h_responses.append(h_full[h_prompt.shape[0] :])
+            # Extract response histories from full text
+            h_responses = _extract_responses_from_full_histories(
+                text_full, prompt_histories, self.chat_template_name, self.tokenizer
+            )
             history_chat_flat.response = torch.stack(h_responses)
     result.set(self.history_key, history_chat)
     return result
