
Commit c651b7b

[misc] fix: support nested datastructure in dataproto to convert to tensordict (verl-project#4296)
## What does this PR do?

Fixes `ValueError: TensorDict conversion only supports... Got <class 'list'>` when converting a `DataProto` with nested non-tensor data to a `TensorDict`.

**Problem:** Agent loop workflows with nested structures (lists of lists, lists of dicts) in `non_tensor_batch` failed during `to_tensordict()` conversion:

- `turn_scores`: `[[], [0.5, 0.8]]` - lists of varying lengths
- `reward_extra_info`: `[{"acc": 1.0}, {"acc": 0.0}]` - lists of dicts
- `raw_prompt`: `[[{"content": "...", "role": "user"}]]` - lists of lists of dicts
- `tool_rewards`: `[[0.0], []]` - lists of lists

**Solution:** Wrap nested data in `NonTensorStack` (TensorDict's supported type for non-tensor sequences) instead of converting it to plain Python lists.

**Impact:** Enables agent loop and multi-turn dialogue workflows to use DataProto ↔ TensorDict conversions without errors.

---

## Test

Added 5 tests in `tests/test_protocol_on_cpu.py`:

1. **`test_to_tensordict_with_nested_lists`** - lists of lists (e.g., `turn_scores`)
2. **`test_to_tensordict_with_nested_dicts`** - lists of dicts (e.g., `reward_extra_info`)
3. **`test_to_tensordict_with_complex_nested_structures`** - lists of lists of dicts (e.g., `raw_prompt`)
4. **`test_to_tensordict_and_back_with_nested_data`** - round-trip data integrity
5. **`test_to_tensordict_agent_loop_scenario`** - real-world agent loop scenario with all nested types

All tests verify:

- ✅ No conversion errors
- ✅ Data accessibility and correctness
- ✅ Round-trip conversion preserves data

Run the tests:

```bash
pytest tests/test_protocol_on_cpu.py -k "test_to_tensordict" -v
```

---

## Design & Code Changes

### Modified Files

**1. `verl/protocol.py` (lines 1118-1133)**

```python
# Before: plain list conversion (fails for nested structures)
tensor_batch[key] = val.tolist()

# After: wrap each element in NonTensorData and stack
from tensordict.tensorclass import NonTensorData, NonTensorStack

tensor_batch[key] = NonTensorStack.from_list([NonTensorData(item) for item in val])
```

**2. `verl/utils/tensordict_utils.py` (lines 109-127)**

```python
# Skip validation for NonTensorStack objects (already properly formatted),
# while still checking batch-size consistency
if isinstance(val, NonTensorStack):
    if batch_size is None:
        batch_size = len(val)
    continue
```

### Why This Works

- `NonTensorStack` is TensorDict's native type for storing sequences of arbitrary Python objects
- It preserves nested structures (lists, dicts, complex objects) without serialization
- It maintains batch semantics: each element corresponds to one batch sample
- It enables round-trip conversion (`DataProto → TensorDict → DataProto`) without data loss

---

## Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs) (not needed: this is a bug fix, not a new API).
- [x] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code (5 new tests added).
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)

Signed-off-by: petersh6 <petershengwhu@gmail.com>
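The failure mode the PR fixes can be sketched with plain numpy, independent of verl: `non_tensor_batch` values are object ndarrays with one Python object per sample, and the old code flattened them with `tolist()`, producing exactly the `<class 'list'>` named in the `ValueError`. A minimal illustration (standalone sketch, not verl code):

```python
import numpy as np

# Build a ragged object array the way agent-loop data arrives:
# one (possibly empty) list per batch sample.
turn_scores = np.empty(3, dtype=object)
for i, item in enumerate([[], [0.5, 0.8], [0.9]]):
    turn_scores[i] = item

assert turn_scores.dtype == object  # ragged data forces an object dtype

# The old conversion path flattened this back to a plain Python list,
# which is the <class 'list'> type that TensorDict rejected:
as_plain_list = turn_scores.tolist()
assert isinstance(as_plain_list, list)
assert as_plain_list == [[], [0.5, 0.8], [0.9]]
```

The nesting itself is intact after `tolist()`; the problem was purely that TensorDict does not accept a bare `list` as a batch value, which is what the `NonTensorStack` wrapping addresses.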
1 parent d0997d2 commit c651b7b

File tree: 3 files changed (+226, −1 lines)


tests/test_protocol_on_cpu.py

Lines changed: 211 additions & 0 deletions

The new tests are inserted between `test_from_tensordict` and `test_serialize_deserialize_single_tensor` (they rely on the file's existing imports of `pytest`, `tensordict`, `torch`, `np`, `parse_version`, and `DataProto`):

```python
@pytest.mark.skipif(
    parse_version(tensordict.__version__) < parse_version("0.10"), reason="requires at least tensordict 0.10"
)
def test_to_tensordict_with_nested_lists():
    """Test converting DataProto with nested lists to TensorDict (lists of lists)."""
    obs = torch.tensor([1, 2, 3])
    # Simulate turn_scores or tool_rewards: array of lists with varying lengths
    turn_scores = [[], [0.5, 0.8], [0.9]]

    data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"turn_scores": turn_scores})

    # This should not raise an error
    tensordict_output = data.to_tensordict()

    # Verify the data is preserved
    assert torch.all(torch.eq(tensordict_output["obs"], obs)).item()
    # Verify nested structure is accessible (TensorDict wraps NonTensorStack as LinkedList)
    retrieved_scores = tensordict_output["turn_scores"]
    assert len(retrieved_scores) == len(turn_scores)
    # Verify content matches
    assert list(retrieved_scores[0]) == []
    assert list(retrieved_scores[1]) == [0.5, 0.8]
    assert list(retrieved_scores[2]) == [0.9]


@pytest.mark.skipif(
    parse_version(tensordict.__version__) < parse_version("0.10"), reason="requires at least tensordict 0.10"
)
def test_to_tensordict_with_nested_dicts():
    """Test converting DataProto with lists of dicts to TensorDict."""
    obs = torch.tensor([1, 2, 3])
    # Simulate reward_extra_info: array of dicts
    reward_extra_info = [{"acc": 1.0}, {"acc": 0.0}, {"acc": 1.0}]

    data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"reward_extra_info": reward_extra_info})

    # This should not raise an error - this was the original bug
    tensordict_output = data.to_tensordict()

    # Verify the data is preserved
    assert torch.all(torch.eq(tensordict_output["obs"], obs)).item()
    # Verify nested dicts are accessible
    retrieved_info = tensordict_output["reward_extra_info"]
    assert len(retrieved_info) == len(reward_extra_info)
    # Verify content matches
    for i, expected_dict in enumerate(reward_extra_info):
        assert dict(retrieved_info[i]) == expected_dict


@pytest.mark.skipif(
    parse_version(tensordict.__version__) < parse_version("0.10"), reason="requires at least tensordict 0.10"
)
def test_to_tensordict_with_complex_nested_structures():
    """Test converting DataProto with complex nested structures (lists of lists of dicts)."""
    obs = torch.tensor([1, 2, 3])
    # Simulate raw_prompt: array of lists containing dicts
    raw_prompt = [
        [{"content": "Question 1", "role": "user"}],
        [{"content": "Question 2", "role": "user"}, {"content": "Answer 2", "role": "assistant"}],
        [{"content": "Question 3", "role": "user"}],
    ]

    data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"raw_prompt": raw_prompt})

    # This should not raise an error
    tensordict_output = data.to_tensordict()

    # Verify the data is preserved
    assert torch.all(torch.eq(tensordict_output["obs"], obs)).item()
    # Verify complex nested structure is accessible
    retrieved_prompt = tensordict_output["raw_prompt"]
    assert len(retrieved_prompt) == len(raw_prompt)
    # Spot check: verify first prompt has correct structure
    assert len(retrieved_prompt[0]) == 1
    assert dict(retrieved_prompt[0][0]) == {"content": "Question 1", "role": "user"}


@pytest.mark.skipif(
    parse_version(tensordict.__version__) < parse_version("0.10"), reason="requires at least tensordict 0.10"
)
def test_to_tensordict_and_back_with_nested_data():
    """Test round-trip conversion: DataProto → TensorDict → DataProto with nested structures."""
    obs = torch.tensor([1, 2, 3, 4])
    labels = ["a", "b", "c", "d"]

    # Multiple types of nested structures
    turn_scores = [[], [0.5], [0.8, 0.9], [0.7]]
    reward_extra_info = [
        {"acc": 1.0, "loss": 0.1},
        {"acc": 0.5, "loss": 0.3},
        {"acc": 1.0, "loss": 0.05},
        {"acc": 0.0, "loss": 0.9},
    ]
    raw_prompt = [
        [{"content": "Q1", "role": "user"}],
        [{"content": "Q2", "role": "user"}],
        [{"content": "Q3", "role": "user"}, {"content": "A3", "role": "assistant"}],
        [{"content": "Q4", "role": "user"}],
    ]

    # Create original DataProto
    original_data = DataProto.from_dict(
        tensors={"obs": obs},
        non_tensors={
            "labels": labels,
            "turn_scores": turn_scores,
            "reward_extra_info": reward_extra_info,
            "raw_prompt": raw_prompt,
        },
        meta_info={"experiment": "test_nested"},
    )

    # Convert to TensorDict
    tensordict_output = original_data.to_tensordict()

    # Convert back to DataProto
    reconstructed_data = DataProto.from_tensordict(tensordict_output)

    # Verify tensors are preserved
    assert torch.all(torch.eq(reconstructed_data.batch["obs"], obs)).item()

    # Verify non-tensor data is preserved
    assert reconstructed_data.non_tensor_batch["labels"].tolist() == labels

    # Verify nested structures are preserved
    assert len(reconstructed_data.non_tensor_batch["turn_scores"]) == len(turn_scores)
    for orig, recon in zip(turn_scores, reconstructed_data.non_tensor_batch["turn_scores"], strict=True):
        assert list(orig) == list(recon)

    assert len(reconstructed_data.non_tensor_batch["reward_extra_info"]) == len(reward_extra_info)
    for orig, recon in zip(reward_extra_info, reconstructed_data.non_tensor_batch["reward_extra_info"], strict=True):
        assert orig == recon

    assert len(reconstructed_data.non_tensor_batch["raw_prompt"]) == len(raw_prompt)
    for orig, recon in zip(raw_prompt, reconstructed_data.non_tensor_batch["raw_prompt"], strict=True):
        assert orig == list(recon)

    # Verify meta_info is preserved
    assert reconstructed_data.meta_info["experiment"] == "test_nested"


@pytest.mark.skipif(
    parse_version(tensordict.__version__) < parse_version("0.10"), reason="requires at least tensordict 0.10"
)
def test_to_tensordict_agent_loop_scenario():
    """Test the exact scenario from agent loop: DataProto with tool rewards, acc, etc.

    This test reproduces the exact error from the agent loop where nested structures
    (lists of lists, lists of dicts) failed to convert to TensorDict.
    """
    # Simulate real agent loop data structure
    prompts = torch.tensor([[1, 2, 3], [4, 5, 6]])
    responses = torch.tensor([[7, 8], [9, 10]])

    # Non-tensor data with nested structures from agent loop
    data_source = ["lighteval/MATH", "lighteval/MATH"]
    uid = ["uuid-1", "uuid-2"]
    num_turns = np.array([2, 4], dtype=np.int32)
    acc = np.array([1.0, 0.0])
    turn_scores = [[], [0.5, 0.8]]  # Lists of varying lengths
    reward_extra_info = [{"acc": 1.0}, {"acc": 0.0}]  # List of dicts
    raw_prompt = [
        [{"content": "Compute 4 @ 2", "role": "user"}],
        [{"content": "Compute 8 @ 7", "role": "user"}],
    ]
    tool_rewards = [[0.0], []]  # List of lists

    data = DataProto.from_dict(
        tensors={"prompts": prompts, "responses": responses},
        non_tensors={
            "data_source": data_source,
            "uid": uid,
            "num_turns": num_turns,
            "acc": acc,
            "turn_scores": turn_scores,
            "reward_extra_info": reward_extra_info,
            "raw_prompt": raw_prompt,
            "tool_rewards": tool_rewards,
        },
        meta_info={"global_steps": 42},
    )

    # THE KEY TEST: This should not raise ValueError about TensorDict conversion
    tensordict_output = data.to_tensordict()

    # Verify tensors are accessible
    assert torch.all(torch.eq(tensordict_output["prompts"], prompts)).item()
    assert torch.all(torch.eq(tensordict_output["responses"], responses)).item()

    # Verify all nested structures are accessible (content check, not type check)
    assert len(tensordict_output["turn_scores"]) == 2
    assert list(tensordict_output["turn_scores"][0]) == []
    assert list(tensordict_output["turn_scores"][1]) == [0.5, 0.8]

    assert len(tensordict_output["reward_extra_info"]) == 2
    assert dict(tensordict_output["reward_extra_info"][0]) == {"acc": 1.0}

    assert len(tensordict_output["raw_prompt"]) == 2
    assert dict(tensordict_output["raw_prompt"][0][0]) == {"content": "Compute 4 @ 2", "role": "user"}

    assert len(tensordict_output["tool_rewards"]) == 2
    assert list(tensordict_output["tool_rewards"][0]) == [0.0]
    assert list(tensordict_output["tool_rewards"][1]) == []

    # Verify round-trip conversion works perfectly
    reconstructed = DataProto.from_tensordict(tensordict_output)
    assert len(reconstructed) == 2
    assert reconstructed.meta_info["global_steps"] == 42
    assert torch.all(torch.eq(reconstructed.batch["prompts"], prompts)).item()
```

verl/protocol.py

Lines changed: 4 additions & 1 deletion

```diff
@@ -1118,14 +1118,17 @@ def to_tensordict(self) -> TensorDict:
         tensor_batch = self.batch.to_dict()
         non_tensor_batch = self.non_tensor_batch

+        from tensordict.tensorclass import NonTensorData, NonTensorStack
+
         from verl.utils import tensordict_utils as tu

         common_keys = set(tensor_batch.keys()) & set(non_tensor_batch.keys())
         assert len(common_keys) == 0, f"tensor_batch and non_tensor_batch have common keys {common_keys}"

         for key, val in non_tensor_batch.items():
             assert isinstance(val, np.ndarray)
-            tensor_batch[key] = val.tolist()
+            # Convert to NonTensorStack instead of plain list to handle nested structures
+            tensor_batch[key] = NonTensorStack.from_list([NonTensorData(item) for item in val])
         output = tu.get_tensordict(tensor_dict=tensor_batch, non_tensor_dict=self.meta_info)
         return output
```
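What the per-element wrapping above buys can be shown without tensordict installed. The classes below are illustrative stand-ins (NOT the real `NonTensorData`/`NonTensorStack` API): the point is that wrapping each sample individually keeps its nested Python structure, rather than flattening the whole batch into one plain list.

```python
from dataclasses import dataclass


@dataclass
class NonTensorDataStub:
    """Stand-in for tensordict's NonTensorData: holds one sample's object."""
    data: object


class NonTensorStackStub(list):
    """Stand-in for tensordict's NonTensorStack: a length-indexed batch of samples."""
    @classmethod
    def from_list(cls, items):
        return cls(items)


# Mirror the fixed conversion line: one wrapper per batch element
reward_extra_info = [{"acc": 1.0}, {"acc": 0.0}]
stack = NonTensorStackStub.from_list([NonTensorDataStub(item) for item in reward_extra_info])

assert len(stack) == 2                # batch semantics: one entry per sample
assert stack[0].data == {"acc": 1.0}  # nested dict survives untouched
assert stack[1].data == {"acc": 0.0}
```

The design choice is that the stack, not a Python `list`, carries the batch dimension, so TensorDict can treat it like any other batched value.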

verl/utils/tensordict_utils.py

Lines changed: 11 additions & 0 deletions

```diff
@@ -110,6 +110,17 @@ def get_tensordict(tensor_dict: dict[str, torch.Tensor | list], non_tensor_dict:
         if isinstance(val, torch.Tensor) and val.is_nested:
             assert val.is_contiguous(), "Nested tensors must be contiguous. Try setting layout=torch.jagged"

+        # Skip validation for NonTensorStack as it's already properly formatted
+        if isinstance(val, NonTensorStack):
+            if batch_size is None:
+                batch_size = len(val)
+            else:
+                assert len(val) == batch_size, (
+                    f"Batch size of NonTensorStack {key} is not consistent with other tensors. "
+                    f"Expected {batch_size}, got {len(val)}"
+                )
+            continue
+
         if isinstance(val, list):
             for v in val:
                 assert not isinstance(v, torch.Tensor), (
```
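The consistency rule the new branch enforces is simple: every value must agree on one batch size, inferred from the first value seen. A plain-Python sketch (function name and signature are illustrative, not verl's API; lists stand in for `NonTensorStack`):

```python
def infer_batch_size(values: dict) -> int:
    """Infer a common batch size across values, mirroring get_tensordict's check."""
    batch_size = None
    for key, val in values.items():
        if batch_size is None:
            batch_size = len(val)  # first value seen sets the batch size
        else:
            assert len(val) == batch_size, (
                f"Batch size of {key} is not consistent with other entries. "
                f"Expected {batch_size}, got {len(val)}"
            )
    return batch_size


# Consistent lengths pass and yield the shared batch size
assert infer_batch_size({"turn_scores": [[], [0.5, 0.8]], "uid": ["u1", "u2"]}) == 2

# Mismatched lengths are rejected, as in get_tensordict
try:
    infer_batch_size({"acc": [1.0, 0.0], "tool_rewards": [[0.0]]})
except AssertionError:
    pass
else:
    raise RuntimeError("expected a batch-size mismatch error")
```

Note that the real code `continue`s after this check, so a `NonTensorStack` never reaches the plain-`list` validation below it.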
