-
Notifications
You must be signed in to change notification settings - Fork 678
[WIP] Proper tool calling support in the torchtune #2794
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 4 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,7 +6,7 @@ | |
|
||
import json | ||
|
||
from typing import Any, Optional | ||
from typing import Any, Optional, Mapping | ||
|
||
import jinja2 | ||
from jinja2 import StrictUndefined | ||
|
@@ -90,8 +90,11 @@ def _infer_bos_eos_tokens(self): | |
self.eos_token = "<eos>" | ||
|
||
if self.config: | ||
self.bos_token = self._get_token_from_config(self.config, "bos_token") | ||
self.eos_token = self._get_token_from_config(self.config, "eos_token") | ||
try: | ||
self.bos_token = self._get_token_from_config(self.config, "bos_token") | ||
self.eos_token = self._get_token_from_config(self.config, "eos_token") | ||
except ValueError: | ||
pass | ||
|
||
if self.bos_token is not None: | ||
self.bos_id = self.tokenizer.token_to_id(self.bos_token) | ||
if self.eos_token is not None: | ||
|
@@ -103,9 +106,6 @@ def _infer_bos_eos_tokens(self): | |
if self.eos_id is None: | ||
self.eos_id = self.generation_config.get("eos_token_id") | ||
|
||
if self.bos_id is None or self.eos_id is None: | ||
raise ValueError("Could not infer BOS and EOS token IDs from config") | ||
|
||
def _infer_should_add_bos_eos(self): | ||
""" | ||
Hugging Face tokenizers sometimes add BOS by default. We should infer this to determine | ||
|
@@ -136,9 +136,16 @@ def encode( | |
list[int]: The list of token ids. | ||
""" | ||
token_ids = self.tokenizer.encode(text).ids | ||
if add_bos and not self.hf_adds_bos and self.bos_token not in text: | ||
|
||
# Both bos_id and eos_id might be None (null). Therefore, we need an additional check. | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is this related to tool-calling? Or a separate issue? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It is caused by a separate issue in HuggingFaceBaseTokenizer. |
||
if ( | ||
add_bos | ||
and not self.hf_adds_bos | ||
and self.bos_token not in text | ||
and self.bos_id | ||
): | ||
token_ids.insert(0, self.bos_id) | ||
if add_eos and not self.hf_adds_eos: | ||
if add_eos and not self.hf_adds_eos and self.eos_id: | ||
token_ids.append(self.eos_id) | ||
return token_ids | ||
|
||
|
@@ -221,13 +228,15 @@ def __init__( | |
*, | ||
tokenizer_config_json_path: Optional[str] = None, | ||
generation_config_path: Optional[str] = None, | ||
max_seq_len: Optional[int] = None, | ||
truncation_type: str = "right", | ||
): | ||
self.base_tokenizer = HuggingFaceBaseTokenizer( | ||
tokenizer_json_path=tokenizer_json_path, | ||
tokenizer_config_json_path=tokenizer_config_json_path, | ||
generation_config_path=generation_config_path, | ||
) | ||
self.max_seq_len = max_seq_len | ||
|
||
# Contents of the tokenizer_config.json | ||
config = self.base_tokenizer.config | ||
|
@@ -272,15 +281,18 @@ def tokenize_messages( | |
self, | ||
messages: list[Message], | ||
add_eos: bool = True, | ||
max_seq_len: Optional[int] = None, | ||
) -> tuple[list[int], list[bool]]: | ||
tokenized_messages = [] | ||
mask = [] | ||
previous_tokens = [] | ||
|
||
for i, message in enumerate(messages): | ||
current_messages = [ | ||
{"role": m.role, "content": m.content[0]["content"]} | ||
{ | ||
"role": m.role, | ||
"content": m.content[0]["content"], | ||
"tool_calls": m.tool_calls, | ||
} | ||
} | There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I had some issues specifically with the LLaMA 3(.3) tokenizer here, which didn't play nicely with empty tool calls: current_messages = [
{
"role": m.role,
"content": m.content[0]["content"],
**({"tool_calls": m.tool_calls} if m.tool_calls is not None else {})
}
for m in messages[: i + 1]
] I'm not sure if this is the correct logic, though (or if it plays nicely with all tokenizers). There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This works fine with other tokenizers! Thanks. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. For me this broke Qwen Coder (sigh). I think removing StrictUndefined may have fixed it? I don't recall exactly. I kept trying components from the transformers Jinja compile until I found something, but I don't know much about how Jinja works — there may be a better solution. |
||
for m in messages[: i + 1] | ||
] | ||
|
||
|
@@ -310,16 +322,26 @@ def tokenize_messages( | |
# Finally, truncate if necessary | ||
tokenized_messages = truncate( | ||
tokens=tokenized_messages, | ||
max_seq_len=max_seq_len, | ||
max_seq_len=self.max_seq_len, | ||
eos_id=self.base_tokenizer.eos_id, | ||
truncation_type=self.truncation_type, | ||
) | ||
|
||
mask = truncate( | ||
tokens=mask, | ||
max_seq_len=max_seq_len, | ||
max_seq_len=self.max_seq_len, | ||
eos_id=True if add_eos else None, | ||
truncation_type=self.truncation_type, | ||
) | ||
|
||
return tokenized_messages, mask | ||
|
||
def __call__(self, sample: Mapping[str, Any], inference: bool = False) -> Mapping[str, Any]: | ||
""" | ||
Apply ``tokenize_messages`` to the "messages" field in the sample. | ||
""" | ||
messages = sample.pop("messages") | ||
tokens, mask = self.tokenize_messages(messages, add_eos=not inference) | ||
sample["tokens"] = tokens | ||
sample["mask"] = mask | ||
return sample |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we also update the role "ipython" to "tool" to match what's done by Hugging Face?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, good catch — that argument seemed weird to me.