11import inspect
22import re
3+ import warnings
34from dataclasses import dataclass , field
45from functools import lru_cache
5- from typing import Callable , Dict , Hashable , Optional , cast
6+ from typing import Callable , Dict , Hashable , Optional , Tuple , cast
67
78from jinja2 import Environment , StrictUndefined
89
@@ -29,6 +30,8 @@ class Template:
2930 The template to render.
3031 signature
3132 The prompt function's signature.
33+ model
34+ The model the `Template` is associated with. Defaults to `None`.
3235 registry
3336 Registry that maps function names to their respective `Template`
3437 instances.
@@ -50,7 +53,7 @@ def __call__(self, *args, **kwargs) -> str:
5053 """
5154 bound_arguments = self .signature .bind (* args , ** kwargs )
5255 bound_arguments .apply_defaults ()
53- return render (self .template , ** bound_arguments .arguments )
56+ return render (self .template , self . model , ** bound_arguments .arguments )
5457
5558 def __str__ (self ):
5659 return self .template
@@ -74,6 +77,7 @@ def __getitem__(self, model_name: str):
7477 try :
7578 return self .registry [model_name ]
7679 except KeyError :
80+ self .model = model_name
7781 return self
7882
7983 def register (self , model_name : str ):
@@ -140,13 +144,21 @@ def template(fn: Callable) -> Template:
140144
141145
142146@lru_cache
143- def render (template : str , ** values : Optional [Dict [str , Hashable ]]) -> str :
147+ def render (
148+ template : str ,
149+ model_name : Optional [str ] = None ,
150+ ** values : Optional [Dict [str , Hashable ]],
151+ ) -> str :
144152 r"""Parse a Jinja2 template and translate it into an Outlines graph.
145153
146154 This function removes extra whitespaces and linebreaks from templates to
147155 allow users to enter prompts more naturally than if they used Python's
148156 constructs directly. See the examples for a detailed explanation.
149157
158+ We also define the `bos` and `eos` special variables which, when used, will
159+ be replaced by the model's BOS and EOS tokens respectively. This allows you
160+ to write prompts that are model-agnostic.
161+
150162 Examples
151163 --------
152164
@@ -223,6 +235,8 @@ def render(template: str, **values: Optional[Dict[str, Hashable]]) -> str:
223235 ----------
224236 template
225237 A string that contains a template written with the Jinja2 syntax.
238+ model_name
239+ The name of the model to which the rendered string will be passed.
226240 **values
227241 Map from the variables in the template to their value.
228242
@@ -245,12 +259,34 @@ def render(template: str, **values: Optional[Dict[str, Hashable]]) -> str:
245259 # used to continue to the next line without linebreak.
246260 cleaned_template = re .sub (r"(?![\r\n])(\b\s+)" , " " , cleaned_template )
247261
262+ # Warn the user when the model is not present in the special token registry
263+ if model_name not in SPECIAL_TOKENS :
264+ warnings .warn (
265+ UserWarning (
266+ f"The model { model_name } is not present in the special token registry."
267+ "As a result, EOS and BOS tokens will be rendered as the empty string."
268+ "Please open an issue: https://github.com/outlines-dev/prompts/issues"
269+ "And ask for the model to be added to the registry."
270+ )
271+ )
272+
248273 env = Environment (
249274 trim_blocks = True ,
250275 lstrip_blocks = True ,
251276 keep_trailing_newline = True ,
252277 undefined = StrictUndefined ,
253278 )
279+ env .globals ["bos" ] = SPECIAL_TOKENS .get (model_name , ("" , "" ))[0 ]
280+ env .globals ["eos" ] = SPECIAL_TOKENS .get (model_name , ("" , "" ))[1 ]
254281 jinja_template = env .from_string (cleaned_template )
255282
256283 return jinja_template .render (** values )
284+
285+
286+ # Maps a model name to its (BOS, EOS) special tokens; the `None` key covers calls made without a model name.
287+ SPECIAL_TOKENS : Dict [Optional [str ], Tuple [str , str ]] = {
288+ None : ("" , "" ),
289+ "google/gemma-2-9b" : ("<bos>" , "<eos>" ),
290+ "openai-community/gpt2" : ("" , "<|endoftext|>" ),
291+ "mistralai/Mistral-7B-v0.1" : ("<s>" , "</s>" ),
292+ }
0 commit comments