Skip to content

Commit 8ce11a8

Browse files
author
Vincent Moens
committed
[Feature] History.default_spec
ghstack-source-id: 40b8a49 Pull Request resolved: #2894
1 parent 4ba5066 commit 8ce11a8

File tree

3 files changed

+97
-1
lines changed

3 files changed

+97
-1
lines changed

test/test_cost.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16676,6 +16676,9 @@ def forward(self, td, mode):
1667616676

1667716677

1667816678
class TestPPO4LLMs:
16679+
@pytest.mark.skipif(
16680+
not _has_transformers, reason="transformers lib required to test PPO with LLMs"
16681+
)
1667916682
@set_capture_non_tensor_stack(False)
1668016683
@pytest.mark.parametrize("from_text", [True, False])
1668116684
def test_hf(self, from_text):

test/test_rb.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4097,6 +4097,22 @@ def test_history_template_recover(self, mock_history, tokenizer):
40974097
)
40984098
recovered = history._inv_chatml(tokenizer.batch_decode(data_token)[0])
40994099

4100+
def test_history_spec(self):
    """Check that the default History spec accepts its zero sample and a real history."""
    roles = ["system", "user", "assistant", "user"]
    messages = [
        "i'm the system",
        "i'm the user",
        "I'm the assistant",
        "I'm the user again",
    ]
    chat = History(role=roles, content=messages)
    default_spec = chat.default_spec()
    zero_sample = default_spec.zero()
    # The zero sample must be rebuilt as a History, not as a plain container.
    assert isinstance(zero_sample, History)
    # Both the generated sample and the hand-built history must satisfy the spec.
    assert default_spec.is_in(zero_sample)
    assert default_spec.is_in(chat)
4115+
41004116

41014117
if __name__ == "__main__":
41024118
args, unknown = argparse.ArgumentParser().parse_known_args()

torchrl/data/llm/chat.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@
44
# LICENSE file in the root directory of this source tree.
55
from __future__ import annotations
66

7+
import dataclasses
8+
79
import re
10+
from typing import Literal
811

912
import torch
1013

@@ -107,10 +110,11 @@ def apply_chat_template(
107110
padding: bool | str = False,
108111
truncation: bool | str = False,
109112
return_tensors: str | None = "pt",
113+
**kwargs,
110114
):
111115
"""Applies a chat template to the history.
112116
113-
Args:
117+
Keyword Args:
114118
tokenizer (transformers.PreTrainedTokenizer): The tokenizer to use.
115119
add_generation_prompt (bool, optional): Whether to add a generation prompt. Defaults to True.
116120
chat_template (str, optional): The chat template to use. Defaults to _TEMPLATES["chatml_format"].
@@ -119,6 +123,7 @@ def apply_chat_template(
119123
padding (bool | str, optional): The padding strategy to use. Defaults to False.
120124
truncation (bool | str, optional): The truncation strategy to use. Defaults to False.
121125
return_tensors (str | None, optional): The type of tensors to return. Defaults to "pt".
126+
**kwargs: Additional keyword arguments to pass to the tokenizer `apply_chat_template` method.
122127
123128
Returns:
124129
The formatted history.
@@ -135,6 +140,17 @@ def apply_chat_template(
135140
continue_final_message=continue_final_message,
136141
)
137142

143+
@classmethod
def inv_chat_template(
    cls, text: str, chat_template_name: Literal["chatml_format"] = "chatml_format"
) -> History:
    """Inverts a chat-template-formatted string into a History object.

    Args:
        text (str): The formatted chat string to invert.
        chat_template_name (str, optional): The name of the template the text was
            formatted with. Only ``"chatml_format"`` is supported for now.
            Defaults to ``"chatml_format"``.

    Returns:
        The History object produced by the template-specific inverse
        (``_inv_chatml`` for ``"chatml_format"``).

    Raises:
        NotImplementedError: if ``chat_template_name`` is not a supported template.
    """
    # Hard coded for now: chatml is the only template with an inverse.
    if chat_template_name == "chatml_format":
        return cls._inv_chatml(text)
    raise NotImplementedError("chat_template_name must be one of ('chatml_format',)")
153+
138154
@classmethod
139155
def _inv_chatml(cls, text: str) -> History:
140156
"""Inverts a chatml string into a History object.
@@ -227,3 +243,64 @@ def extend(
227243
self.__dict__["_tensordict"] = td
228244
return self
229245
return torch.stack(list(self.unbind(dim)) + list(history.unbind(dim)), dim=dim)
246+
247+
@classmethod
def default_spec(cls, shape=(-1,)):
    """A default spec to use in transforms / envs that return History objects.

    Every dataclass field of the class is mapped to a ``NonTensor`` leaf. The
    example data of a leaf is the field's declared default when it has one,
    ``"foo"`` for ``str``-annotated fields and ``None`` otherwise.

    Args:
        shape (torch.Size, optional): The shape of the returned History spec and of
            each of its leaves. Defaults to `(-1,)` (variable length along time
            dimension).

    Returns:
        A ``Composite`` spec with one ``NonTensor`` entry per dataclass field,
        built with ``data_cls=cls``.

    Example:
        >>> import tensordict
        >>> from torchrl.data import History
        >>> tensordict.set_list_to_stack(True).set()
        >>>
        >>> history = History(role=["system", "user"], content=["a message", "another message"], batch_size=(2,))
        >>> spec = history.default_spec()
        >>> print(spec)
        Composite(
            role: NonTensor(
                shape=torch.Size([-1]),
                space=None,
                device=None,
                dtype=None,
                domain=None,
                example_data=foo),
            content: NonTensor(
                shape=torch.Size([-1]),
                space=None,
                device=None,
                dtype=None,
                domain=None,
                example_data=foo),
            device=None,
            shape=torch.Size([-1]))
        >>> print(spec.zero())
        History(
            content=NonTensorData(data=foo, batch_size=torch.Size([1]), device=None),
            role=NonTensorData(data=foo, batch_size=torch.Size([1]), device=None),
            batch_size=torch.Size([1]),
            device=None,
            is_shared=False)

    """
    # Imported locally, as in the original implementation.
    from torchrl.data import Composite, NonTensor

    def get_default_value(field):
        # Prefer the default declared on the dataclass field itself.
        if field.default is not dataclasses.MISSING:
            return field.default
        # Annotations may be stored as strings (the module uses
        # `from __future__ import annotations`), hence the "str" check.
        if field.type in (str, "str"):
            return "foo"
        return None

    # The leaves use the requested shape rather than a hard-coded (-1,), so
    # that they stay consistent with the Composite that contains them.
    defaults = {
        name: NonTensor(example_data=get_default_value(field), shape=shape)
        for name, field in cls.__dataclass_fields__.items()
    }

    return Composite(defaults, shape=shape, data_cls=cls)

0 commit comments

Comments
 (0)