|
19 | 19 | import os
|
20 | 20 | import re
|
21 | 21 | from functools import wraps
|
22 |
| -from typing import Any, Dict, List, Union |
| 22 | +from typing import Any, Dict, List, Optional, Union |
23 | 23 |
|
24 | 24 | from transformers import BatchEncoding
|
25 | 25 | from transformers.tokenization_utils import (
|
@@ -156,6 +156,61 @@ def wrapper(*args, **kwargs):
|
156 | 156 |
|
157 | 157 | setattr(self, method_name, wrapper)
|
158 | 158 |
|
def apply_chat_template(
    self,
    conversation: Union[list[dict[str, str]], list[list[dict[str, str]]], dict[str, Any]],
    chat_template: Optional[str] = None,
    **kwargs,
):
    """Apply the chat template to conversation data (supports 3 formats).

    1. Standard chat format:
        [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi! How can I help?"}
        ]

    2. Batch conversation format:
        [
            [{"role": "user", "content": "user messages"}, {"role": "assistant", "content": "assistant messages"}],
            [{"role": "user", "content": "user messages"}]
        ]

    3. Enhanced dictionary format (not natively supported by HuggingFace):
        {
            "messages": [
                {"role": "user", "content": "Query"},
                {"role": "assistant", "content": "Response"}
            ],
            "tools": [],      # Function call definitions
            "documents": []   # RAG context documents
        }

    Args:
        conversation: Conversation data in one of the three formats above.
        chat_template: Forwarded unchanged to the parent implementation.
        **kwargs: Forwarded to the parent implementation. For the dictionary
            format, ``tools`` and ``documents`` are consumed here: a non-empty
            value inside ``conversation`` takes precedence, otherwise the
            value supplied through ``kwargs`` (if any) is used.

    Returns:
        Whatever the parent class's ``apply_chat_template`` returns.
    """
    # List-based formats need no preprocessing; hand them straight to the parent.
    if not isinstance(conversation, dict):
        return super().apply_chat_template(
            conversation=conversation,
            chat_template=chat_template,
            **kwargs,
        )

    messages = conversation.get("messages", None)

    resolved = {}
    for key in ("tools", "documents"):
        value = conversation.get(key, None)
        if key in kwargs:
            # Always remove the key from kwargs: it is passed explicitly below,
            # and leaving it in place would raise
            # "TypeError: got multiple values for keyword argument".
            override = kwargs.pop(key)
            # The kwargs value only fills in for a missing/empty dict value.
            if not value:
                value = override
        resolved[key] = value

    return super().apply_chat_template(
        conversation=messages,
        chat_template=chat_template,
        tools=resolved["tools"],
        documents=resolved["documents"],
        **kwargs,
    )
| 213 | + |
159 | 214 | # Rewrite hf's tokenizer function from_pretrained
|
160 | 215 | @classmethod
|
161 | 216 | def from_pretrained(
|
|
0 commit comments