4
4
# LICENSE file in the root directory of this source tree.
5
5
from __future__ import annotations
6
6
7
+ import dataclasses
8
+
7
9
import re
10
+ from typing import Literal
8
11
9
12
import torch
10
13
@@ -107,10 +110,11 @@ def apply_chat_template(
107
110
padding : bool | str = False ,
108
111
truncation : bool | str = False ,
109
112
return_tensors : str | None = "pt" ,
113
+ ** kwargs ,
110
114
):
111
115
"""Applies a chat template to the history.
112
116
113
- Args:
117
+ Keyword Args:
114
118
tokenizer (transformers.PreTrainedTokenizer): The tokenizer to use.
115
119
add_generation_prompt (bool, optional): Whether to add a generation prompt. Defaults to True.
116
120
chat_template (str, optional): The chat template to use. Defaults to _TEMPLATES["chatml_format"].
@@ -119,6 +123,7 @@ def apply_chat_template(
119
123
padding (bool | str, optional): The padding strategy to use. Defaults to False.
120
124
truncation (bool | str, optional): The truncation strategy to use. Defaults to False.
121
125
return_tensors (str | None, optional): The type of tensors to return. Defaults to "pt".
126
+ **kwargs: Additional keyword arguments to pass to the tokenizer `apply_chat_template` method.
122
127
123
128
Returns:
124
129
The formatted history.
@@ -135,6 +140,17 @@ def apply_chat_template(
135
140
continue_final_message = continue_final_message ,
136
141
)
137
142
143
@classmethod
def inv_chat_template(
    cls, text: str, chat_template_name: Literal["chatml_format"] = "chatml_format"
) -> History:
    """Parses a chat-template-formatted string back into a ``History``.

    Args:
        text (str): the formatted conversation string to invert.
        chat_template_name (str, optional): name of the template that produced
            ``text``. Only ``"chatml_format"`` is currently supported.
            Defaults to ``"chatml_format"``.

    Returns:
        History: the conversation parsed from ``text``.

    Raises:
        NotImplementedError: if ``chat_template_name`` is not a supported template.
    """
    # Hard coded for now
    supported_templates = ("chatml_format",)
    if chat_template_name in supported_templates:
        return cls._inv_chatml(text)
    raise NotImplementedError(
        "chat_template_name must be one of ('chatml_format',)"
    )
153
+
138
154
@classmethod
139
155
def _inv_chatml (cls , text : str ) -> History :
140
156
"""Inverts a chatml string into a History object.
@@ -227,3 +243,64 @@ def extend(
227
243
self .__dict__ ["_tensordict" ] = td
228
244
return self
229
245
return torch .stack (list (self .unbind (dim )) + list (history .unbind (dim )), dim = dim )
246
+
247
@classmethod
def default_spec(cls, shape=(-1,)):
    """A default spec to use in transforms / envs that return History objects.

    Builds a :class:`~torchrl.data.Composite` spec with one
    :class:`~torchrl.data.NonTensor` leaf per dataclass field of ``cls``,
    each leaf carrying a variable-length time dimension (``shape=(-1,)``).

    Args:
        shape (torch.Size, optional): The shape of the returned History spec.
            Defaults to ``(-1,)`` (variable length along the time dimension).

    Example:
        >>> import tensordict
        >>> from torchrl.data import History
        >>> tensordict.set_list_to_stack(True).set()
        >>>
        >>> history = History(role=["system", "user"], content=["a message", "another message"], batch_size=(2,))
        >>> spec = history.default_spec()
        >>> print(spec)
        Composite(
            role: NonTensor(
                shape=torch.Size([-1]),
                space=None,
                device=None,
                dtype=None,
                domain=None,
                example_data=foo),
            content: NonTensor(
                shape=torch.Size([-1]),
                space=None,
                device=None,
                dtype=None,
                domain=None,
                example_data=foo),
            device=None,
            shape=torch.Size([-1]))
        >>> print(spec.zero())
        History(
            content=NonTensorData(data=foo, batch_size=torch.Size([1]), device=None),
            role=NonTensorData(data=foo, batch_size=torch.Size([1]), device=None),
            batch_size=torch.Size([1]),
            device=None,
            is_shared=False)

    """
    from torchrl.data import Composite, NonTensor

    def _example_value(field):
        # Prefer an explicit dataclass default; otherwise synthesize a
        # placeholder for string-typed fields and fall back to None.
        if field.default is not dataclasses.MISSING:
            return field.default
        if field.type in (str, "str"):
            return "foo"
        return None

    leaf_specs = {}
    for name, field in cls.__dataclass_fields__.items():
        # Each leaf is variable-length along the time dimension.
        leaf_specs[name] = NonTensor(example_data=_example_value(field), shape=(-1,))

    return Composite(leaf_specs, shape=shape, data_cls=cls)
0 commit comments