Skip to content

Commit 23764af

Browse files
SdeeRK and Jonathans575 authored
enhance apply_chat_template (#2513)
Co-authored-by: Jonathans575 <[email protected]>
1 parent 495645b commit 23764af

File tree

2 files changed

+59
-2
lines changed

2 files changed

+59
-2
lines changed

paddleformers/transformers/auto/tokenizer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
resolve_trust_remote_code,
2525
)
2626
from transformers.modeling_gguf_pytorch_utils import load_gguf_checkpoint
27-
from transformers.models import EncoderDecoderConfig
2827
from transformers.models.auto.configuration_auto import (
2928
config_class_to_model_type,
3029
replace_list_option_in_docstrings,
@@ -35,6 +34,9 @@
3534
get_tokenizer_config,
3635
tokenizer_class_from_name,
3736
)
37+
from transformers.models.encoder_decoder.configuration_encoder_decoder import (
38+
EncoderDecoderConfig,
39+
)
3840
from transformers.tokenization_utils_base import TOKENIZER_CONFIG_FILE
3941
from transformers.utils import cached_file
4042

paddleformers/transformers/tokenizer_utils.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import os
2020
import re
2121
from functools import wraps
22-
from typing import Any, Dict, List, Union
22+
from typing import Any, Dict, List, Optional, Union
2323

2424
from transformers import BatchEncoding
2525
from transformers.tokenization_utils import (
@@ -156,6 +156,61 @@ def wrapper(*args, **kwargs):
156156

157157
setattr(self, method_name, wrapper)
158158

159+
def apply_chat_template(
    self,
    conversation: Union[list[dict[str, str]], list[list[dict[str, str]]], dict[str, Any]],
    chat_template: Optional[str] = None,
    **kwargs,
):
    """Apply the chat template to conversation data (supports 3 formats).

    1. Standard chat format:
        [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi! How can I help?"}
        ]

    2. Batch conversation format:
        [
            [{"role": "user", "content": "user messages"}, {"role": "assistant", "content": "assistant messages"}],
            [{"role": "user", "content": "user messages"}]
        ]

    3. Enhanced dictionary format (not natively supported by HuggingFace):
        {
            "messages": [
                {"role": "user", "content": "Query"},
                {"role": "assistant", "content": "Response"}
            ],
            "tools": [],      # Function call definitions
            "documents": []   # RAG context documents
        }

    Args:
        conversation: One of the three formats above. For the dictionary
            format, the value under ``"messages"`` is forwarded as the
            conversation and ``"tools"`` / ``"documents"`` are forwarded as
            the keyword arguments of the same name.
        chat_template: Forwarded unchanged to the parent implementation.
        **kwargs: Forwarded to the parent implementation. For the dictionary
            format, ``tools`` / ``documents`` given here are used only when
            the corresponding dictionary entry is missing or empty; a
            non-empty dictionary entry takes precedence.

    Returns:
        Whatever the parent class's ``apply_chat_template`` returns.
    """
    # List-based inputs need no unpacking; delegate unchanged.
    if not isinstance(conversation, dict):
        return super().apply_chat_template(
            conversation=conversation,
            chat_template=chat_template,
            **kwargs,
        )

    extras = {}
    for key in ("tools", "documents"):
        value = conversation.get(key)
        if key in kwargs:
            # Always remove the key from kwargs: leaving it there while also
            # passing it explicitly below would raise
            # "TypeError: got multiple values for keyword argument".
            fallback = kwargs.pop(key)
            # The kwarg only fills in for a missing/empty dictionary entry.
            if not value:
                value = fallback
        extras[key] = value

    return super().apply_chat_template(
        conversation=conversation.get("messages"),
        chat_template=chat_template,
        **extras,
        **kwargs,
    )
213+
159214
# Rewrite hf's tokenizer function from_pretrained
160215
@classmethod
161216
def from_pretrained(

0 commit comments

Comments
 (0)