# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
import dataclasses
import re
from typing import Literal, TYPE_CHECKING
import torch
from tensordict import (
lazy_stack,
LazyStackedTensorDict,
list_to_stack,
TensorClass,
TensorDict,
)
from tensordict.utils import _maybe_correct_neg_dim
from torchrl._utils import logger as torchrl_logger
if TYPE_CHECKING:
import transformers
# Global storage for custom templates and their metadata
_CHAT_TEMPLATES = {
"chatml_format": """{% for message in messages %}
{%- if message['role'] == 'assistant' %}
{% generation %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endgeneration %}
{%- else %}
{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}
{%- endif %}
{% endfor %}
{%- if add_generation_prompt %}
{% generation %}{{- '<|im_start|>assistant\n' }}{% endgeneration %}
{%- endif %}
""",
"qwen": """
{%- if tools %}
{{- '<|im_start|>system\\n' }}
{%- if messages[0]['role'] == 'system' %}
{{- messages[0]['content'] }}
{%- else %}
{{- 'You are a helpful assistant.' }}
{%- endif %}
{{- "\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>" }}
{%- for tool in tools %}
{{- "\\n" }}
{{- tool | tojson }}
{%- endfor %}
{{- "\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n" }}
{%- else %}
{%- if messages[0]['role'] == 'system' %}
{{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}
{%- else %}
{{- '<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n' }}
{%- endif %}
{%- endif %}
{%- for message in messages %}
{%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
{{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}
{%- elif (message.role == "assistant" and not message.tool_calls) %}
{% generation %} {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }} {% endgeneration %}
{%- elif message.role == "assistant" %}
{% generation %}{{- '<|im_start|>' + message.role }}
{%- if message.content %}
{{- '\\n' + message.content }}
{%- endif %}
{%- for tool_call in message.tool_calls %}
{%- if tool_call.function is defined %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{{- '\\n<tool_call>\\n{\\\"name\\\": \\\"' }}
{{- tool_call.name }}
{{- '\\\", \\\"arguments\\\": ' }}
{{- tool_call.arguments | tojson }}
{{- '}\\n</tool_call>' }}
{%- endfor %}
{{- '<|im_end|>\\n' }}{% endgeneration %}
{%- elif message.role == "tool" %}
{%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %}
{{- '<|im_start|>tool' }}
{%- endif %}
{{- '\\n<tool_response>\\n' }}
{%- if message.tool_responses %}
{{- message.tool_responses }}
{%- else %}
{{- message.content }}
{%- endif %}
{{- '\\n</tool_response>' }}
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
{{- '<|im_end|>\\n' }}
{%- endif %}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
{% generation %}{{- '<|im_start|>assistant\\n' }}{% endgeneration %}
{%- endif %}
""",
"dialogpt": """{% for message in messages %}{% if message['role'] == 'assistant' %}{% generation %}{{ message['content'] }}{% endgeneration %}{{ eos_token }}{% elif message['role'] == 'user' %}{{ message['content'] }}{{ eos_token }}{% endif %}{% endfor %}{% if add_generation_prompt %}{% generation %}{{ ' ' }}{% endgeneration %}{% endif %}""",
"falcon": """{% for message in messages %}{% if message['role'] == 'assistant' %}{% generation %}{{ 'Assistant: ' + message['content'] }}{% endgeneration %}\n\n{% elif message['role'] == 'user' %}{{ 'User: ' + message['content'] }}\n\n{% elif message['role'] == 'system' %}{{ message['content'] }}\n\n{% endif %}{% endfor %}{% if add_generation_prompt %}{% generation %}{{ 'Assistant: ' }}{% endgeneration %}{% endif %}""",
"deepseek": """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{% generation %}{{ 'Assistant: ' + message['content'] + eos_token }}{% endgeneration %}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{% generation %}{{ 'Assistant:' }}{% endgeneration %}{% endif %}""",
"llama": """{{- bos_token }}
{%- if messages[0]['role'] == 'system' %}
{%- set system_message = messages[0]['content']|trim %}
{%- set messages = messages[1:] %}
{%- else %}
{%- set system_message = "" %}
{%- endif %}
{%- if system_message %}
{{- "<|header_start|>system<|header_end|>\n\n" }}
{{- system_message }}
{{- "<|eot|>" }}
{%- endif %}
{%- for message in messages %}
{%- if message['role'] == 'assistant' %}
{% generation %}{{- '<|header_start|>' + message['role'] + '<|header_end|>\n\n' }}
{%- if message['content'] is string %}
{{- message['content'] }}
{%- else %}
{%- for content in message['content'] %}
{%- if content['type'] == 'text' %}
{{- content['text'] | trim }}
{%- endif %}
{%- endfor %}
{%- endif %}
{{- "<|eot|>" }}{% endgeneration %}
{%- else %}
{{- '<|header_start|>' + message['role'] + '<|header_end|>\n\n' }}
{%- if message['content'] is string %}
{{- message['content'] }}
{%- else %}
{%- for content in message['content'] %}
{%- if content['type'] == 'text' %}
{{- content['text'] | trim }}
{%- endif %}
{%- endfor %}
{%- endif %}
{{- "<|eot|>" }}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
{% generation %}{{- '<|header_start|>assistant<|header_end|>\n\n' }}{% endgeneration %}
{%- endif %}""",
}
# Global storage for custom template metadata
_CUSTOM_INVERSE_PARSERS = {}
_CUSTOM_MODEL_FAMILY_KEYWORDS = {}
def add_chat_template(
template_name: str,
template: str,
inverse_parser: callable | None = None,
model_family_keywords: list[str] | None = None,
) -> None:
r"""Add a custom chat template to the global template dictionary.
This function allows you to add custom chat templates for new model families
that support assistant token masking via the `{% generation %}` keyword.
Args:
template_name (str): The name of the template (e.g., "llama", "mistral").
This name will be used in the `chat_template_name` parameter of
`History.apply_chat_template()` and `History.from_text()`.
template (str): The Jinja2 template string. Must include `{% generation %}`
blocks around assistant message content to enable token masking.
inverse_parser (callable, optional): A function that parses formatted text back
into a History object. Should have signature `(text: str) -> History`.
If None, a basic parser will be used.
model_family_keywords (list[str], optional): Keywords to detect this model family
in the auto-detection logic. For example, ["llama", "meta-llama"] for Llama models.
If provided, the template will be automatically selected for models containing
these keywords in their name.
Example:
>>> from torchrl.data.llm.chat import add_chat_template, History
>>> from transformers import AutoTokenizer
>>>
>>> # Add a custom template for Llama models
>>> llama_template = '''
... {% for message in messages %}
... {%- if message['role'] == 'user' %}
... {{ '<s>[INST] ' + message['content'] + ' [/INST]' }}
... {%- elif message['role'] == 'assistant' %}
... {% generation %}{{ message['content'] + '</s>' }}{% endgeneration %}
... {%- endif %}
... {% endfor %}
... {%- if add_generation_prompt %}
... {% generation %}{{ ' ' }}{% endgeneration %}
... {%- endif %}
... '''
>>>
>>> def parse_llama_text(text: str) -> History:
... # Custom parser for Llama format
... import re
... pattern = r'<s>\[INST\]\s*(.*?)\s*\[/INST\]\s*(.*?)</s>'
... matches = re.findall(pattern, text, re.DOTALL)
... messages = []
... for user_content, assistant_content in matches:
... messages.append(History(role="user", content=user_content.strip()))
... messages.append(History(role="assistant", content=assistant_content.strip()))
... return lazy_stack(messages)
>>>
>>> # Add the template with auto-detection
>>> add_chat_template(
... template_name="llama",
... template=llama_template,
... inverse_parser=parse_llama_text,
... model_family_keywords=["llama", "meta-llama"]
... )
>>>
>>> # Now you can use it with auto-detection
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
>>> history = History.from_chats([[
... {"role": "user", "content": "Hello"},
... {"role": "assistant", "content": "Hi there!"}
... ]])
>>>
>>> # Auto-detection will use the llama template
>>> result = history.apply_chat_template(
... tokenizer=tokenizer,
... add_generation_prompt=False,
... return_dict=True,
... return_assistant_tokens_mask=True,
... )
>>>
>>> # Or use it explicitly
>>> result = history.apply_chat_template(
... tokenizer=tokenizer,
... chat_template_name="llama",
... add_generation_prompt=False,
... return_dict=True,
... return_assistant_tokens_mask=True,
... )
.. note::
- The template must include `{% generation %}` blocks around assistant message
content to enable assistant token masking.
- The inverse parser should handle the specific format of your template.
- Model family keywords are case-insensitive and matched against the tokenizer's
`name_or_path` attribute.
- Templates are stored globally and persist for the duration of the Python session.
"""
global _CHAT_TEMPLATES, _CUSTOM_INVERSE_PARSERS, _CUSTOM_MODEL_FAMILY_KEYWORDS # noqa: F824
# Validate template contains generation blocks
if "{% generation %}" not in template:
raise ValueError(
f"Template '{template_name}' must include '{{% generation %}}' blocks "
"around assistant message content to enable token masking."
)
# Add template to dictionary
_CHAT_TEMPLATES[template_name] = template
# Store inverse parser if provided
if inverse_parser is not None:
_CUSTOM_INVERSE_PARSERS[template_name] = inverse_parser
# Store model family keywords if provided
if model_family_keywords is not None:
_CUSTOM_MODEL_FAMILY_KEYWORDS[template_name] = model_family_keywords
torchrl_logger.info(
f"Added custom chat template '{template_name}' with assistant token masking support"
)
# We need the 'shadow' flag to prevent tensordict from complaining about 'type'/'size' etc. fields
class ContentBase(TensorClass["nocast", "shadow"]):
"""Base class for all message content types.
Attributes:
type (str): The type of the content.
text (str, optional): The text content.
url (str, optional): The URL content.
data (str, optional): The data content.
mime_type (str, optional): The MIME type of the content.
name (str, optional): The name of the content.
size (int, optional): The size of the content.
function_name (str, optional): The name of the function.
function_args (dict, optional): The arguments of the function.
Examples:
>>> from tensordict import lazy_stack
>>> content1 = ContentBase(type="text", text="Hello, world!")
>>> print(content1)
ContentBase(
text=NonTensorData(data=Hello, world!, batch_size=torch.Size([]), device=None),
type=NonTensorData(data=text, batch_size=torch.Size([]), device=None),
url=None,
data=None,
mime_type=None,
name=None,
size=None,
function_name=None,
function_args=None,
batch_size=torch.Size([]),
device=None,
is_shared=False)
>>> content2 = ContentBase(type="image", url="https://example.com/image.jpg")
>>> print(content2)
ContentBase(
type=NonTensorData(data=image, batch_size=torch.Size([]), device=None),
url=NonTensorData(data=https://example.com/image.jpg, batch_size=torch.Size([]), device=None),
text=None,
data=None,
mime_type=None,
name=None,
size=None,
function_name=None,
function_args=None,
batch_size=torch.Size([]),
device=None,
is_shared=False)
>>> content = lazy_stack([content1, content2])
>>> print(content)
ContentBase(
type=NonTensorStack(
['text', 'image'],
batch_size=torch.Size([2]),
device=None),
url=None,
data=None,
mime_type=None,
name=None,
size=None,
function_name=None,
function_args=None,
text=None,
batch_size=torch.Size([2]),
device=None,
is_shared=False)
>>> # A content is typically used in a History object. Usually, its batch dimension is
>>> # one dimension greater than the History object.
>>> history = History(role="user", content=content)
"""
type: Literal[
"text", "image", "audio", "video", "file", "function_call"
] # Required: "text", "image", "audio", "video", "file", "function_call"
# Text content
text: str | None = None
# Media/file content (either URL or data)
url: str | None = None # HTTP URL to content
data: str | None = None # Base64 encoded content
# Metadata
mime_type: str | None = None # "image/jpeg", "audio/mp3", "application/pdf"
name: str | None = None # Original filename or description
size: int | None = None # File size in bytes
# Function calling (for AI agents)
function_name: str | None = None
function_args: dict | None = None
class History(TensorClass["nocast"]):
"""A class representing a structured history of messages in a conversation, designed for efficient manipulation and integration with language models.
The `History` class provides a centralized API for managing conversational data, offering several advantages over
traditional list-based approaches:
- Centralized API for conversion to and from string formats, facilitating seamless integration with language models.
- Efficient methods to append, extend, and reshape history elements, enabling dynamic construction of conversation
trajectories, especially useful in reinforcement learning environments.
- Interoperability with the `transformers` API, allowing for easy tokenization and preparation of input data.
- **Assistant token masking support** across multiple model families for reinforcement learning applications.
**Recent Changes:**
- **ChatHistory Integration**: History objects are now used within :class:`~torchrl.modules.llm.policies.ChatHistory`
containers for structured conversation management in LLM environments.
- **Modular Wrapper Support**: Both vLLMWrapper and TransformersWrapper now use History objects when `input_mode="history"`
is specified, providing consistent conversation state management.
- **Environment Integration**: ChatEnv and related environments use History objects for state management and conversation tracking.
.. note:: The `"<none>"` role is used to indicate that the element is a placeholder,
for example when the tool call was not executed but a stack requires a certain number of elements
per batch to have congruent shapes. The :meth:`~torchrl.data.llm.chat.History.apply_chat_template`
method will remove the `<none>` role from the history.
**Assistant Token Masking Support:**
The class supports assistant token masking across multiple model families, allowing you to identify which tokens
in a conversation were generated by the assistant. This is crucial for reinforcement learning applications.
**Supported Model Families:**
- **Qwen family** (e.g., `Qwen/Qwen2.5-0.5B`): Custom template with full tool calling support
- **DialoGPT family** (e.g., `microsoft/DialoGPT-medium`): Custom template for conversation format
- **Falcon family** (e.g., `tiiuae/falcon-7b-instruct`): Custom template for instruction format
- **DeepSeek family** (e.g., `deepseek-ai/deepseek-coder-6.7b-base`): Custom template with native format
- **Other models** (OPT, GPT, MPT, BLOOM, Pythia, Phi, etc.): Default `chatml_format` template
**Example with Assistant Token Masking:**
.. code-block:: python
>>> from torchrl.data.llm.chat import History
>>> from torchrl.modules.llm.policies import ChatHistory
>>> from transformers import AutoTokenizer
>>>
>>> # Create a conversation history
>>> history = History.from_chats([[
... {"role": "user", "content": "Hello"},
... {"role": "assistant", "content": "Hi there!"},
... {"role": "user", "content": "How are you?"},
... {"role": "assistant", "content": "I'm doing well, thanks!"}
... ]])
>>>
>>> # Create ChatHistory container for LLM wrapper
>>> chat_history = ChatHistory(prompt=history)
>>>
>>> # Load any supported tokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
>>>
>>> # Apply chat template with assistant token masking
>>> result = history.apply_chat_template(
... tokenizer=tokenizer,
... add_generation_prompt=False,
... return_dict=True,
... return_assistant_tokens_mask=True,
... )
>>>
>>> # The result contains an assistant_masks tensor
>>> assistant_masks = result["assistant_masks"]
>>> print(f"Assistant tokens: {assistant_masks.sum().item()}")
**Integration with LLM Wrappers:**
History objects work seamlessly with the new modular wrapper design:
.. code-block:: python
>>> from torchrl.modules.llm import TransformersWrapper
>>> from torchrl.modules.llm.policies import ChatHistory
>>>
>>> # Create wrapper with history input mode
>>> wrapper = TransformersWrapper(
... model, tokenizer=tokenizer,
... input_mode="history",
... generate=True,
... return_log_probs=True
... )
>>>
>>> # Use History with ChatHistory container
>>> history = History.from_chats([[
... {"role": "user", "content": "Hello"},
... {"role": "assistant", "content": "Hi there!"}
... ]])
>>> chat_history = ChatHistory(prompt=history)
>>> result = wrapper(TensorDict(history=chat_history, batch_size=(1,)))
>>> print(result["history"].response) # New response from LLM
Attributes:
role (str): The role of the message sender.
content (str): The content of the message.
is_complete (bool): Whether the message was properly terminated with an end token. Defaults to `True`.
tool_calls (list[dict] | None): Optional list of tool calls in the message.
tool_responses (list[str] | None): Optional list of tool responses.
Methods:
apply_chat_template: converts the `History` object to str / tokens.
from_text: parses a chat-template-formatted string back into a `History` object.
append: append one element to the list of items along a given dimension.
extend: extend the list of items along a given dimension.
Examples:
>>> # With tensordict < 0.10, we need to tell the lib that lists constitute batches
>>> import tensordict
>>> tensordict.set_list_to_stack(True).set()
>>> import transformers
>>> history0 = History(
... role='system',
... content='''CONTENT
... This is the setup''',
... )
>>> history1 = History(
... role='user',
... content='''CONTENT
... This is the first user prompt''',
... )
>>> history2 = History(
... role='assistant',
... content='''CONTENT
... This is the second prompt, the first for the assistant.''',
... )
>>> history = torch.stack([history0, history1, history2])
>>> assert history.role == ['system', 'user', 'assistant']
>>> tokenizer = transformers.AutoTokenizer.from_pretrained("GPT2")
>>> # Apply a template to pass the history to an LLM. Note that the output has
>>> # an additional prompt to elicit an answer from the LLM thanks to the 'add_generation_prompt' argument.
>>> parsed_string = history.apply_chat_template(tokenizer=tokenizer, add_generation_prompt=True)
>>> parsed_string
<|im_start|>system
CONTENT
This is the setup<|im_end|>
<|im_start|>user
CONTENT
This is the first user prompt<|im_end|>
<|im_start|>assistant
CONTENT
This is the second prompt, the first for the assistant.<|im_end|>
<|im_start|>assistant
.. seealso::
:class:`~torchrl.modules.llm.policies.ChatHistory`: Container for managing conversation data in LLM environments.
:class:`~torchrl.modules.llm.policies.Text`: Container for text data.
:class:`~torchrl.modules.llm.policies.Tokens`: Container for token data.
"""
role: str | list[str] | list[list[str]]
content: str | ContentBase | list[str] | list[ContentBase] | list[list[str]] | list[
list[ContentBase]
]
is_complete: bool = True
tool_calls: list[dict] | None = None
tool_responses: list[str] | None = None
def __post_init__(self):
if not list_to_stack():
raise RuntimeError(
"Please set the list_to_stack to True using tensordict.set_list_to_stack(True).set() at the beginning of your script, "
"or the LIST_TO_STACK=1 environment variable."
)
def apply_chat_template(
self,
*,
tokenizer: transformers.AutoTokenizer | transformers.AutoProcessor, # noqa
add_generation_prompt: bool = True,
chat_template: str | None = None,
chat_template_name: str | None = None,
continue_final_message: bool = False,
tokenize: bool | None = None,
padding: bool | str = False,
truncation: bool | str = False,
return_tensors: str | None = None,
return_dict: bool | None = None,
return_assistant_tokens_mask: bool = False,
**kwargs,
) -> str | list[str] | TensorDict:
"""Applies a chat template to the history.
Keyword Args:
tokenizer (transformers.PreTrainedTokenizer | transformers.AutoProcessor): The tokenizer to use.
add_generation_prompt (bool, optional): Whether to add a generation prompt (e.g. `"<|im_start|>assistant"`). Defaults to `True`.
chat_template (str, optional): The chat template to use. Defaults to the tokenizer's default template.
chat_template_name (str, optional): The name of the chat template to use.
Takes precedence over `tokenizer.chat_template`. If `None`, the method will automatically detect the model family and use the appropriate template.
Defaults to `None`.
continue_final_message (bool, optional): Whether to continue the final message. Defaults to `False`.
tokenize (bool, optional): Whether to tokenize the output. If `None`, tokenization is enabled when `return_assistant_tokens_mask` is `True` or `return_tensors` is set, and disabled otherwise. Defaults to `None`.
padding (bool | str, optional): The padding strategy to use. Defaults to `False`.
truncation (bool | str, optional): The truncation strategy to use. Defaults to `False`.
return_tensors (str | None, optional): The type of tensors to return. Defaults to `None`; set to `"pt"` when tokenization is enabled.
return_dict (bool, optional): Whether to return a dictionary. Defaults to `True` when `return_assistant_tokens_mask=True`, `False` otherwise.
return_assistant_tokens_mask (bool, optional): Whether to return a mask of the assistant generated tokens.
If `True`, the mask will be written to the `assistant_masks` key.
For tokens generated by the assistant, the mask will contain `1`.
For user and system tokens, the mask will contain `0`.
This functionality is only available for chat templates that support it via the `{% generation %}` keyword.
Defaults to `False`.
.. note:: Assistant token masking is supported across multiple model families:
- **Qwen family**: Uses custom template with full tool calling support
- **DialoGPT family**: Uses custom template for conversation format
- **Falcon family**: Uses custom template for instruction format
- **DeepSeek family**: Uses custom template with native format
- **Other models**: Use the default `chatml_format` template
The method automatically detects the model family and selects the appropriate template.
**kwargs: Additional keyword arguments to pass to the tokenizer `apply_chat_template` method.
Returns:
The formatted history.
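Examples:
    >>> # Minimal usage sketch, assuming `history` and `tokenizer` as in the class-level
    >>> # example above. The "assistant_masks" entry requires `return_dict=True` and a
    >>> # chat template with `{% generation %}` blocks (see the note above).
    >>> result = history.apply_chat_template(
    ...     tokenizer=tokenizer,
    ...     add_generation_prompt=False,
    ...     return_dict=True,
    ...     return_assistant_tokens_mask=True,
    ... )
    >>> assistant_masks = result["assistant_masks"]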
"""
if chat_template is None:
if chat_template_name is not None:
chat_template = _CHAT_TEMPLATES[chat_template_name]
chat_template_name = None
elif tokenizer is None:
raise RuntimeError(
"You must specify a tokenizer to use when chat_template is not specified."
)
else:
# Auto-detect model family and use appropriate template
model_name = getattr(tokenizer, "name_or_path", "").lower()
# First check for custom model family keywords
custom_template_found = False
for template_name, keywords in _CUSTOM_MODEL_FAMILY_KEYWORDS.items():
if any(keyword.lower() in model_name for keyword in keywords):
chat_template = _CHAT_TEMPLATES[template_name]
chat_template_name = None
custom_template_found = True
break
if not custom_template_found:
# Fall back to built-in model family detection
if "qwen" in model_name:
# We prefer our implementation of the Qwen template,
# since it supports assistant token masking.
chat_template = _CHAT_TEMPLATES["qwen"]
chat_template_name = None
elif "dialogpt" in model_name or "microsoft/dialo" in model_name:
# DialoGPT family - use our custom template
chat_template = _CHAT_TEMPLATES["dialogpt"]
chat_template_name = None
elif "falcon" in model_name or "tiiuae/falcon" in model_name:
# Falcon family - use our custom template
chat_template = _CHAT_TEMPLATES["falcon"]
chat_template_name = None
elif "deepseek" in model_name:
# DeepSeek family - use our custom template with generation keyword
chat_template = _CHAT_TEMPLATES["deepseek"]
chat_template_name = None
elif "llama" in model_name:
# Llama family - use our custom template
chat_template = _CHAT_TEMPLATES["llama"]
chat_template_name = None
else:
# For other models, check if their default template supports generation
if (
hasattr(tokenizer, "chat_template")
and tokenizer.chat_template
and "{% generation %}" in tokenizer.chat_template
):
# Use the model's own template if it supports generation
chat_template = tokenizer.chat_template
else:
# Use our default chatml_format template
chat_template = _CHAT_TEMPLATES["chatml_format"]
if chat_template is None:
chat_template = _CHAT_TEMPLATES["chatml_format"]
if tokenize is None:
if return_assistant_tokens_mask or return_tensors is not None:
tokenize = True
else:
tokenize = False
if tokenize:
if return_tensors is None:
return_tensors = "pt"
if return_dict is None and return_assistant_tokens_mask:
return_dict = True
elif return_dict is None:
return_dict = False
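# Batched histories (more than one batch dimension): apply the template to each
# sub-history and stack the results (or return them as a plain list when return_dict is False).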
if self.ndim > 1:
result = [
self[i].apply_chat_template(
tokenizer=tokenizer,
add_generation_prompt=add_generation_prompt,
chat_template=chat_template,
chat_template_name=chat_template_name,
tokenize=tokenize,
padding=padding,
truncation=truncation,
return_tensors=return_tensors,
continue_final_message=continue_final_message,
return_dict=return_dict,
return_assistant_tokens_mask=return_assistant_tokens_mask,
**kwargs,
)
for i in range(self.batch_size[0])
]
if return_dict:
return lazy_stack(result)
else:
return result
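# Single conversation: flatten the time dimension so the tokenizer receives a flat list of message dicts.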
self_flat = self.view(-1)
# tolist_first=True is needed to get a list of message dicts whose nested fields are lists of dicts, rather than a list of dicts of dicts
self_flat = self_flat.tolist(tolist_first=True)
# Remove the "<none>" role
self_flat = [item for item in self_flat if item["role"] != "<none>"]
result = tokenizer.apply_chat_template(
conversation=self_flat,
add_generation_prompt=add_generation_prompt,
chat_template=chat_template,
tokenize=tokenize,
padding=padding,
truncation=truncation,
return_tensors=return_tensors,
continue_final_message=continue_final_message,
return_dict=return_dict,
return_assistant_tokens_mask=return_assistant_tokens_mask,
**kwargs,
)
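# Dict-like outputs (e.g. a BatchEncoding when return_dict=True) are converted to a
# TensorDict so they can be indexed and stacked like other tensordict data.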
if not isinstance(result, (torch.Tensor, list, str)):
result = TensorDict.from_dict(result, auto_batch_size=True, batch_dims=1)
# If self has a batch_dims of 1, we have just the time dimension, so we need to remove the batch dim from the result
if self.batch_dims == 1:
if result.batch_size[0] != 1:
raise RuntimeError(
f"Expected a batch size of 1, got {result.batch_size[0]}."
)
result = result.squeeze(0)
return result
@classmethod
def from_text(
cls,
text: str | list[str],
chat_template_name: str | None = None,
# currently without effect
chat_template: str | None = None,
tokenizer: transformers.AutoTokenizer # noqa: F821
| transformers.AutoProcessor # noqa: F821
| None = None,
) -> History:
r"""Inverts a chat template into a History object.
Args:
text (str | list[str]): The chat template to invert.
chat_template_name (str, optional): The name of the chat template to use.
tokenizer (transformers.AutoTokenizer | transformers.AutoProcessor, optional): The tokenizer to use.
Returns:
History: The inverted History object.
Examples:
>>> from torchrl.data.llm.history import History
>>> from transformers import AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
>>> text = "<|im_start|>system\nYou are a helpful assistant.\n<|im_end|>\n<|im_start|>user\nWrite a python script that gives the capital of France or Germany.\n<|im_end|>\n<|im_start|>assistant\n<think>The capital of France is Paris, the capital of Germany is Berlin.</think>\n<answer><python>\n"
>>> history = History.from_text(text, tokenizer=tokenizer)
>>> print(history)
History(
content=NonTensorStack(
['You are a helpful assistant.', 'Write a python s...,
batch_size=torch.Size([3]),
device=None),
is_complete=NonTensorStack(
[True, True, False],
batch_size=torch.Size([3]),
device=None),
role=NonTensorStack(
['system', 'user', 'assistant'],
batch_size=torch.Size([3]),
device=None),
tool_calls=None,
tool_responses=None,
batch_size=torch.Size([3]),
device=None,
is_shared=False)
"""
if chat_template_name is None:
if chat_template is not None:
# TODO: find best match given template
pass
model_name = getattr(tokenizer, "name_or_path", "").lower()
# First check for custom model family keywords
custom_template_found = False
for template_name, keywords in _CUSTOM_MODEL_FAMILY_KEYWORDS.items():
if any(keyword.lower() in model_name for keyword in keywords):
chat_template_name = template_name
custom_template_found = True
break
if not custom_template_found:
# Fall back to built-in model family detection
if "qwen" in model_name:
# We can automatically detect the template name from the tokenizer
# and use the precoded parser.
chat_template_name = "qwen"
elif "dialogpt" in model_name or "microsoft/dialo" in model_name:
chat_template_name = "dialogpt"
elif "falcon" in model_name or "tiiuae/falcon" in model_name:
chat_template_name = "falcon"
elif "deepseek" in model_name:
chat_template_name = "deepseek"
elif "llama" in model_name:
chat_template_name = "llama"
else:
chat_template_name = "chatml_format"
# Get the appropriate inverse parser function
if chat_template_name in ("chatml_format",):
func = cls._inv_chatml
elif chat_template_name in ("qwen",):
func = cls._inv_qwen
elif chat_template_name in ("dialogpt",):
func = cls._inv_dialogpt
elif chat_template_name in ("falcon",):
func = cls._inv_falcon
elif chat_template_name in ("deepseek",):
func = cls._inv_deepseek
elif chat_template_name in ("llama",):
func = cls._inv_llama
elif chat_template_name in _CUSTOM_INVERSE_PARSERS:
# Use custom inverse parser
func = _CUSTOM_INVERSE_PARSERS[chat_template_name]
else:
raise NotImplementedError(
f"chat_template_name '{chat_template_name}' is not supported. "
"Supported templates: 'chatml_format', 'qwen', 'dialogpt', 'falcon', 'deepseek'. "
"Use add_chat_template() to add custom templates."
)
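# A list of input strings yields one History per string, stacked along a new batch dimension.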
if isinstance(text, list):
list_of_histories = [func(t) for t in text]
try:
return lazy_stack(list_of_histories)
except RuntimeError as e:
raise RuntimeError(
f"Failed to stack histories: {list_of_histories=}"
) from e
return func(text)
@classmethod
def _inv_chatml(cls, text: str) -> History:
"""Inverts a chatml string into a History object.
Args:
text (str): The chatml string to invert.
Returns:
History: The inverted History object.
"""
import json
torchrl_logger.debug(f"Inverting chatml:\n{text}")
# Find all complete blocks (ending with im_end or endoftext)
complete_pattern = r"<\|im_start\|>(.*?)\n(.*?)<\|(im_end|endoftext)\|>"
complete_matches = re.findall(complete_pattern, text, flags=re.DOTALL)
# Find any incomplete block at the end
incomplete_pattern = r"<\|im_start\|>(.*?)\n(.*?)$"
incomplete_matches = []
if complete_matches:
# Look for incomplete block after the last complete one
last_complete = complete_matches[-1]
last_complete_text = f"<|im_start|>{last_complete[0]}\n{last_complete[1]}<|{last_complete[2]}|>"
remaining_text = text[
text.rindex(last_complete_text) + len(last_complete_text) :
]
if remaining_text.strip():
incomplete_match = re.search(
incomplete_pattern, remaining_text, flags=re.DOTALL
)
if incomplete_match:
incomplete_matches = [
(incomplete_match.group(1), incomplete_match.group(2), None)
]
else:
# No complete blocks, check entire text for incomplete block
incomplete_match = re.search(incomplete_pattern, text, flags=re.DOTALL)
if incomplete_match:
incomplete_matches = [
(incomplete_match.group(1), incomplete_match.group(2), None)
]
# Combine complete and incomplete matches
matches = complete_matches + incomplete_matches
# Define tool patterns - same as Qwen for consistency
tool_call_pattern = re.compile(r"<tool_call>\n(.*?)\n</tool_call>", re.DOTALL)
tool_response_pattern = re.compile(
r"<tool_response>\n(.*?)\n</tool_response>", re.DOTALL
)
parsed_messages = []
for match in matches:
role = match[0].strip()
content = match[1].strip()
is_complete = match[2] is not None # None indicates incomplete
# Initialize message dict
message_dict = {
"role": role,
"content": content,
"is_complete": is_complete,
"tool_calls": None,
"tool_responses": None,
}
# Find tool calls within the message
tool_calls = tool_call_pattern.findall(content)
if tool_calls:
tool_calls_list = []
for tool_call in tool_calls:
try:
tool_call_dict = json.loads(tool_call)
tool_calls_list.append(tool_call_dict)
except json.JSONDecodeError:
continue
if tool_calls_list:
message_dict["tool_calls"] = tool_calls_list
# Check for tool responses
tool_responses = tool_response_pattern.findall(content)
if tool_responses:
message_dict["tool_responses"] = tool_responses
parsed_messages.append(cls(**message_dict))
if not parsed_messages:
raise RuntimeError(
f"Couldn't get a single item out of text {text}. A common cause "
f"if that special tokens should not be ommitted, did you set include_stop_str_in_output/skip_special_tokens=False?"
)
return lazy_stack(parsed_messages)
@classmethod
def _inv_qwen(cls, template):
import json
# Define regex patterns for different parts of the template
message_pattern = re.compile(
r"<\|im_start\|>(.*?)(?:<\|(im_end|endoftext)\|>|$)", re.DOTALL
)
tool_call_pattern = re.compile(r"<tool_call>\n(.*?)\n</tool_call>", re.DOTALL)
tool_response_pattern = re.compile(
r"<tool_response>\n(.*?)\n</tool_response>", re.DOTALL
)
# Find all messages and track if they end with a proper token
messages = []
is_complete_list = []
for match in message_pattern.finditer(template):
full_match = match.group(0)
messages.append(match.group(1))
# Check if the message ends with a proper token
is_complete_list.append(
full_match.endswith("<|im_end|>")
or full_match.endswith("<|endoftext|>")
)
parsed_messages = []
for message, is_complete in zip(messages, is_complete_list):
# Split the message into role and content
parts = message.split("\n", 1)
if len(parts) < 2:
continue
role, content = parts[0], parts[1]
# Initialize message dict
message_dict = {
"role": role.strip(),
"content": content.strip(),
"is_complete": is_complete,
"tool_calls": None,
"tool_responses": None,
}
# Find tool calls within the message
tool_calls = tool_call_pattern.findall(content)
if tool_calls:
tool_calls_list = []
for tool_call in tool_calls:
try:
tool_call_dict = json.loads(tool_call)
tool_calls_list.append(tool_call_dict)
except json.JSONDecodeError:
continue
if tool_calls_list:
message_dict["tool_calls"] = tool_calls_list
# Check for tool responses
tool_responses = tool_response_pattern.findall(content)
if tool_responses:
message_dict["tool_responses"] = tool_responses
parsed_messages.append(cls(**message_dict))
if not parsed_messages:
raise RuntimeError(
f"Couldn't get a single item out of text {template}. A common cause "
f"if that special tokens should not be ommitted, did you set include_stop_str_in_output/skip_special_tokens=False?"
)
return lazy_stack(parsed_messages)
@classmethod
def _inv_dialogpt(cls, text: str) -> History:
"""Inverts a DialogPT string into a History object.
Args:
text (str): The DialogPT string to invert.
Returns:
History: The inverted History object.
"""
torchrl_logger.debug(f"Inverting DialogPT:\n{text}")
# DialogPT format is simple: alternating user/assistant messages
# Split by lines and parse
lines = text.strip().split("\n")