From 815b465368b93f68fd550613f90e6a4687be80b5 Mon Sep 17 00:00:00 2001
From: Xavier Fernandes
Date: Wed, 1 May 2024 16:52:04 -0700
Subject: [PATCH] Add TransformersChat code to figure out correct role start and end tokens

---
Review notes: the role start/end lookup now goes through one shared helper,
tolerates tokenizers that lack `default_chat_template` (the original spelled
it `defaut_chat_template`, which raised AttributeError), and raises instead
of slicing at index -1 when the placeholder is missing from the rendered
template. Hunk sizes and the diffstat were updated by hand; the blob hash on
the `index` line still refers to the original post-image.

 guidance/models/transformers/_transformers.py | 102 +++++++++++++++++-
 1 file changed, 101 insertions(+), 1 deletion(-)

diff --git a/guidance/models/transformers/_transformers.py b/guidance/models/transformers/_transformers.py
index 69173ab6d..9b94cb8e3 100644
--- a/guidance/models/transformers/_transformers.py
+++ b/guidance/models/transformers/_transformers.py
@@ -1,5 +1,7 @@
 import os
 import re
+import uuid
+import jinja2
 
 try:
     import torch
@@ -277,4 +279,102 @@ def __init__(
 
 
 class TransformersChat(Transformers, Chat):
-    pass
+
+    def __init__(self, *args, chat_template=None, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # Unique placeholder content; rendering it through the chat template
+        # shows what text surrounds a message for a given role.
+        self._fake_content = str(uuid.uuid4())
+
+    def _split_role_template(self, role_name):
+        """Render a placeholder message and split it around its content.
+
+        Parameters
+        ----------
+        role_name : str
+            The name of the role, like "user", or "assistant"
+
+        Returns
+        -------
+        tuple of (str, str) or None
+            The text rendered before and after the message content, or None
+            when the tokenizer has no chat template.
+
+        Raises
+        ------
+        ValueError
+            If the rendered template does not contain the placeholder.
+        """
+        tokenizer = self.engine.tokenizer._orig_tokenizer
+        # getattr: not every tokenizer defines `default_chat_template`.
+        if (
+            getattr(tokenizer, "chat_template", None) is None
+            and getattr(tokenizer, "default_chat_template", None) is None
+        ):
+            return None
+
+        messages = [{"role": role_name, "content": self._fake_content}]
+        serialized = tokenizer.apply_chat_template(messages, tokenize=False)
+        # partition (unlike find) makes a missing placeholder detectable
+        # instead of silently slicing at index -1.
+        before, found, after = serialized.partition(self._fake_content)
+        if not found:
+            raise ValueError(
+                f"chat template for role {role_name!r} dropped the message content"
+            )
+        return before, after
+
+    def get_role_start(self, role_name, **kwargs):
+        """The starting grammar for a role.
+
+        If the tokenizer has a chat template, this is the text the template
+        renders before a message's content. Otherwise we follow the GPT role
+        tag start conventions.
+
+        Parameters
+        ----------
+        role_name : str
+            The name of the role, like "user", or "assistant"
+        kwargs : dict
+            These kwargs are added to the role start as arguments (GPT-style
+            fallback only; the chat-template path ignores them).
+
+        Returns
+        -------
+        str
+            The text that opens a message for `role_name`.
+        """
+        parts = self._split_role_template(role_name)
+        if parts is not None:
+            return parts[0]
+        return (
+            "<|im_start|>"
+            + role_name
+            + "".join([f' {k}="{v}"' for k, v in kwargs.items()])
+            + "\n"
+        )
+
+    def get_role_end(self, role_name=None):
+        """The ending bytes for a role.
+
+        Note that we cannot use a grammar in closers because they need to remain constant
+        so we can append them whenever we need a representation before the final closing of the context.
+        If the tokenizer has a chat template, this is the text the template
+        renders after a message's content. Otherwise we follow the GPT role
+        tag end conventions.
+
+        Parameters
+        ----------
+        role_name : str
+            The name of the role, like "user", or "assistant"
+
+        Returns
+        -------
+        str
+            The text that closes a message for `role_name`.
+        """
+        parts = self._split_role_template(role_name)
+        if parts is not None:
+            return parts[1]
+        return "<|im_end|>"
\ No newline at end of file