Skip to content

Commit 3122631

Browse files
vermouth1992PeterSH6gemini-code-assist[bot]
authored
[misc] feat: add more utils of tensordict (verl-project#4322)
### What does this PR do? - Add get/get_keys/pop/pop_keys of tensordict ### Checklist Before Starting - [ ] Search for similar PRs. Paste at least one query link here: ... - [ ] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test > For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc. ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python # Add code snippet or script demonstrating how to use this ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [ ] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [ ] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). 
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).) --------- Co-authored-by: Guangming Sheng <petershengwhu@gmail.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent f623c14 commit 3122631

File tree

2 files changed

+109
-11
lines changed

2 files changed

+109
-11
lines changed

tests/test_protocol_v2_on_cpu.py

Lines changed: 56 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -328,17 +328,65 @@ def test_chunk_concat():
328328

329329

330330
def test_pop():
    """Exercise ``tu.pop_keys`` (multi-key) and ``tu.pop`` (single-key) on a mixed tensordict."""
    obs = torch.randn(3, 10)
    act = torch.randn(3, 3)
    labels = ["a", ["b"], []]
    source = tu.get_tensordict(
        {"obs": obs, "act": act, "labels": labels},
        non_tensor_dict={"2": 2, "1": 1},
    )

    # Keep an untouched copy for the single-key checks, since pop_keys mutates `source`.
    pristine = copy.deepcopy(source)

    # Multi-key pop: the result carries exactly the requested entries ...
    removed = tu.pop_keys(source, keys=["obs", "2"])

    assert removed.batch_size[0] == 3
    assert removed.keys() == {"obs", "2"}
    obs_matches = torch.eq(removed["obs"], obs)
    assert torch.all(obs_matches).item()
    assert removed["2"] == 2

    # ... and those entries are gone from the source afterwards.
    assert source.keys() == {"act", "1", "labels"}

    # Popping keys that no longer exist must raise.
    with pytest.raises(KeyError):
        tu.pop_keys(source, keys=["obs", "2"])

    # Single-key pop returns plain Python values / tensors.
    # NonTensorData
    assert tu.pop(pristine, key="2") == 2
    # NonTensorStack
    assert tu.pop(pristine, key="labels") == ["a", ["b"], []]
    # Tensor
    popped_obs = tu.pop(pristine, key="obs")
    assert torch.all(torch.eq(popped_obs, obs)).item()
360+
361+
362+
def test_get():
    """Exercise ``tu.get_keys`` (multi-key) and ``tu.get`` (single-key) on a mixed tensordict."""
    obs = torch.randn(3, 10)
    act = torch.randn(3, 3)
    labels = ["a", ["b"], []]
    dataset = tu.get_tensordict(
        {"obs": obs, "act": act, "labels": labels},
        non_tensor_dict={"2": 2, "1": 1},
    )

    # Multi-key get: the selection mirrors the corresponding entries of the source.
    selected = tu.get_keys(dataset, keys=["obs", "2"])

    assert selected.batch_size[0] == 3
    obs_matches = torch.eq(selected["obs"], dataset["obs"])
    assert torch.all(obs_matches).item()
    assert selected["2"] == dataset["2"]

    # Requesting a key that does not exist must raise.
    with pytest.raises(KeyError):
        tu.get_keys(dataset, keys=["obs", "3"])

    # Single-key get returns plain Python values / tensors.
    # NonTensorData
    assert tu.get(dataset, key="2") == 2
    # NonTensorStack
    assert tu.get(dataset, key="labels") == ["a", ["b"], []]
    # Tensor
    fetched_obs = tu.get(dataset, key="obs")
    assert torch.all(torch.eq(fetched_obs, obs)).item()
    # Missing key falls back to the supplied default
    assert tu.get(dataset, key="3", default=3) == 3
342390

343391

344392
def test_repeat():
@@ -531,7 +579,7 @@ def test_dataproto_no_batch():
531579
selected = data.select("labels")
532580

533581
assert selected["labels"] == labels
534-
pop_data = tu.pop(data, keys=["labels"])
582+
pop_data = tu.pop_keys(data, keys=["labels"])
535583
assert pop_data["labels"] == labels
536584
assert "labels" not in data
537585

verl/utils/tensordict_utils.py

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import logging
16-
from typing import Iterator
16+
from typing import Any, Iterable
1717

1818
import torch
1919
from tensordict import TensorDict
@@ -256,7 +256,8 @@ def union_tensor_dict(tensor_dict1: TensorDict, tensor_dict2: TensorDict) -> Ten
256256
)
257257
for key in tensor_dict2.keys():
258258
if key not in tensor_dict1.keys():
259-
tensor_dict1[key] = tensor_dict2[key]
259+
# Note that there is a difference between tensor_dict2[key] and tensor_dict2.get(key)
260+
tensor_dict1[key] = tensor_dict2.get(key)
260261
else:
261262
if isinstance(tensor_dict2[key], torch.Tensor):
262263
assert tensor_dict1[key].equal(tensor_dict2[key]), (
@@ -325,10 +326,59 @@ def assert_tensordict_eq(tensordict1: TensorDict, tensordict2: TensorDict):
325326
assert val == val2
326327

327328

328-
def pop(tensordict: TensorDict, keys: Iterator[str]) -> TensorDict:
329+
def get(tensordict: TensorDict, key: str, default=None) -> Any:
    """Look up one entry of ``tensordict`` and return it as a plain value.

    Args:
        tensordict: Container to read from.
        key: Name of the entry to fetch.
        default: Value returned when ``key`` is not in ``tensordict``.

    Returns:
        ``default`` if the key is absent. Otherwise the stored
        ``torch.Tensor`` unchanged, the ``tolist()`` result for a
        ``NonTensorStack`` entry, or the ``.data`` payload of a
        ``NonTensorData`` entry.
    """
    if key in tensordict:
        value = tensordict.get(key)
        if isinstance(value, torch.Tensor):
            return value
        if isinstance(value, NonTensorStack):
            return value.tolist()
        # The only remaining supported storage is a single wrapped non-tensor value.
        assert isinstance(value, NonTensorData)
        return value.data
    return default
341+
342+
343+
def get_keys(tensordict: TensorDict, keys: Iterable[str]) -> TensorDict:
    """Collect the requested entries of ``tensordict`` into a new TensorDict.

    Args:
        tensordict: Container to read from.
        keys: Names of the entries to collect.

    Returns:
        The result of ``get_tensordict`` called with the collected tensor-side
        and non-tensor-side dictionaries.

    Raises:
        KeyError: If any requested key is not present in ``tensordict``.
    """
    batched_entries = {}
    scalar_entries = {}
    for key in keys:
        if key not in tensordict.keys():
            raise KeyError(f"key {key} not in tensordict")

        value = tensordict.get(key)
        if isinstance(value, torch.Tensor):
            batched_entries[key] = value
            continue
        if isinstance(value, NonTensorStack):
            # A stack is passed along as a plain list next to the tensors
            # (presumably get_tensordict re-wraps lists; verify against its definition).
            batched_entries[key] = value.tolist()
            continue

        # Anything else must be a single wrapped non-tensor value.
        assert isinstance(value, NonTensorData)
        scalar_entries[key] = value.data

    return get_tensordict(batched_entries, scalar_entries)
359+
360+
361+
def pop(tensordict: TensorDict, key: str, default=None) -> Any:
    """Remove one entry from ``tensordict`` and return it as a plain value.

    Args:
        tensordict: Container to pop from.
        key: Name of the entry to remove.
        default: Value returned when ``key`` is not in ``tensordict``.

    Returns:
        ``default`` if the key is absent. Otherwise the popped
        ``torch.Tensor`` unchanged, the ``tolist()`` result for a
        ``NonTensorStack`` entry, or the ``.data`` payload of a
        ``NonTensorData`` entry.
    """
    # A unique marker distinguishes "key missing" from any value that could be stored.
    missing = object()
    value = tensordict.pop(key, missing)
    if value is missing:
        return default

    if isinstance(value, torch.Tensor):
        return value
    if isinstance(value, NonTensorStack):
        return value.tolist()

    # The only remaining supported storage is a single wrapped non-tensor value.
    assert isinstance(value, NonTensorData)
    return value.data
374+
375+
376+
def pop_keys(tensordict: TensorDict, keys: Iterable[str]) -> TensorDict:
329377
tensor_output = {}
330378
non_tensor_output = {}
331379
for key in keys:
380+
if key not in tensordict.keys():
381+
raise KeyError(f"key {key} not in tensordict")
332382
output = tensordict.get(key)
333383
if isinstance(output, torch.Tensor):
334384
tensor_output[key] = tensordict.pop(key)

0 commit comments

Comments
 (0)