Skip to content

Commit 4be9b25

Browse files
committed
Add model_name property to templates
When rendering templates with special variables representing, e.g., BOS tokens we need to know which model the prompt is going to be passed to so as to render the template with the appropriate token. This PR adds a `model_name` attribute to the `Template` class.
1 parent 4edcf1f commit 4be9b25

File tree

2 files changed

+6
-0
lines changed

2 files changed

+6
-0
lines changed

prompts/templates.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ class Template:
3737

3838
template: str
3939
signature: inspect.Signature
40+
model: Optional[str] = None
4041
registry: Dict[str, Callable] = field(default_factory=dict)
4142

4243
def __call__(self, *args, **kwargs) -> str:
@@ -83,6 +84,7 @@ def register(self, model_name: str):
8384

8485
def wrapper(fn: Callable):
8586
tpl = template(fn)
87+
tpl.model = model_name
8688
self.registry[model_name] = tpl
8789
return tpl
8890

tests/test_templates.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,10 @@ def simple_prompt_name(query: str):
197197
assert callable(simple_prompt)
198198
assert callable(simple_prompt["provider/name"])
199199

200+
assert simple_prompt.model is None
201+
assert simple_prompt_name.model == "provider/name"
202+
assert simple_prompt["provider/name"].model == "provider/name"
203+
200204
assert simple_prompt("test") == "test"
201205
assert simple_prompt_name("test") == "name: test"
202206

0 commit comments

Comments
 (0)