Skip to content

Commit c583f2e

Browse files
committed
Make funtion templates real functions
Prompts remplates are currently contained in the docstring of decorated functions. The main issue with this is that prompt templates cannot be composed. In this commit we instead require users to return the prompt template from the function. The template will then automatically be rendered using the values passed to the function. This is very flexible: some variables can be used inside the functions and not be present in the Jinja2 template that is returned, for instance: ```python import prompts @prompts.template def my_template(a, b): prompt = f'This is a first variable {a}' return prompt + "and a second {{b}}" ```
1 parent 05c9d5e commit c583f2e

File tree

4 files changed

+31
-56
lines changed

4 files changed

+31
-56
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ from prompts import template
2424

2525
@template
2626
def few_shots(instructions, examples, question):
27-
"""{{ instructions }}
27+
return """{{ instructions }}
2828
2929
Examples
3030
--------

docs/reference/template.md

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ will pass to the prompt function.
3838

3939
@prompts.template
4040
def greetings(name, question):
41-
"""Hello, {{ name }}!
41+
return """Hello, {{ name }}!
4242
{{ question }}
4343
"""
4444

@@ -62,7 +62,7 @@ If a variable is missing in the function's arguments, Jinja2 will throw an `Unde
6262

6363
@prompts.template
6464
def greetings(name):
65-
"""Hello, {{ surname }}!"""
65+
return """Hello, {{ surname }}!"""
6666

6767
prompt = greetings("user")
6868
```
@@ -94,7 +94,7 @@ Prompt functions are functions, and thus can be imported from other modules:
9494

9595
@prompts.template
9696
def greetings(name, question):
97-
"""Hello, {{ name }}!
97+
return """Hello, {{ name }}!
9898
{{ question }}
9999
"""
100100
```
@@ -128,7 +128,7 @@ keys `question` and `answer` to the prompt function:
128128

129129
@prompts.template
130130
def few_shots(instructions, examples, question):
131-
"""{{ instructions }}
131+
return """{{ instructions }}
132132

133133
Examples
134134
--------
@@ -207,12 +207,12 @@ below does not matter for formatting:
207207

208208
@prompts.template
209209
def prompt1():
210-
"""My prompt
210+
return """My prompt
211211
"""
212212

213213
@prompts.template
214214
def prompt2():
215-
"""
215+
return """
216216
My prompt
217217
"""
218218

@@ -236,20 +236,20 @@ Indentation is relative to the second line of the docstring, and leading spaces
236236

237237
@prompts.template
238238
def example1():
239-
"""First line
239+
return """First line
240240
Second line
241241
"""
242242

243243
@prompts.template
244244
def example2():
245-
"""
245+
return """
246246
Second line
247247
Third line
248248
"""
249249

250250
@prompts.template
251251
def example3():
252-
"""
252+
return """
253253
Second line
254254
Third line
255255
"""
@@ -285,7 +285,7 @@ You can use the backslash `\` to break a long line of text. It will render as a
285285

286286
@prompts.template
287287
def example():
288-
"""
288+
return """
289289
Break in \
290290
several lines \
291291
But respect the indentation

prompts/templates.py

Lines changed: 13 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import warnings
44
from dataclasses import dataclass, field
55
from functools import lru_cache
6-
from typing import Callable, Dict, Hashable, Optional, cast
6+
from typing import Callable, Dict, Hashable, Optional
77

88
from jinja2 import Environment, StrictUndefined
99

@@ -15,7 +15,7 @@ class Template:
1515
"""Represents a prompt template.
1616
1717
A prompt template is a callable that, given a Jinja2 template and a set of values,
18-
renders the template using those values. It is recommended to instantiate `Temaplate`
18+
renders the template using those values. It is recommended to instantiate `Template`
1919
using the `template` decorator, which extracts the template from the function's
2020
docstring and its variables from the function's signature.
2121
@@ -40,7 +40,7 @@ class Template:
4040
4141
"""
4242

43-
template: str
43+
fn: Callable
4444
signature: inspect.Signature
4545
model: Optional[str] = None
4646
registry: Dict[str, Callable] = field(default_factory=dict)
@@ -55,10 +55,10 @@ def __call__(self, *args, **kwargs) -> str:
5555
"""
5656
bound_arguments = self.signature.bind(*args, **kwargs)
5757
bound_arguments.apply_defaults()
58-
return render(self.template, self.model, **bound_arguments.arguments)
5958

60-
def __str__(self):
61-
return self.template
59+
template = self.fn(**bound_arguments.arguments)
60+
61+
return render(template, self.model, **bound_arguments.arguments)
6262

6363
def __getitem__(self, model_name: str):
6464
"""Get the prompt template corresponding to a model name.
@@ -104,24 +104,23 @@ def template(fn: Callable) -> Template:
104104
manipulation by providing some degree of encapsulation. It uses the `render`
105105
function internally to render templates.
106106
107-
>>> import outlines
107+
>>> import prompts
108108
>>>
109-
>>> @outlines.prompt
109+
>>> @prompts.template
110110
>>> def build_prompt(question):
111-
... "I have a ${question}"
111+
... return "I have a {{question}}"
112112
...
113113
>>> prompt = build_prompt("How are you?")
114114
115115
This API can also be helpful in an "agent" context where parts of the prompt
116116
are set when the agent is initialized and never modified later. In this situation
117117
we can partially apply the prompt function at initialization.
118118
119-
>>> import outlines
120-
>>> import functools as ft
119+
>>> import prompts
121120
...
122-
>>> @outlines.prompt
121+
>>> @prompts.template
123122
... def solve_task(name: str, objective: str, task: str):
124-
... '''Your name is {{name}}.
123+
... return '''Your name is {{name}}.
125124
.. Your overall objective is to {{objective}}.
126125
... Please solve the following task: {{task}}'''
127126
...
@@ -134,15 +133,7 @@ def template(fn: Callable) -> Template:
134133
"""
135134
signature = inspect.signature(fn)
136135

137-
# The docstring contains the template that will be rendered to be used
138-
# as a prompt to the language model.
139-
docstring = fn.__doc__
140-
if docstring is None:
141-
raise TypeError("Could not find a template in the function's docstring.")
142-
143-
template = cast(str, docstring)
144-
145-
return Template(template, signature)
136+
return Template(fn, signature)
146137

147138

148139
@lru_cache

tests/test_templates.py

Lines changed: 7 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -129,9 +129,8 @@ def test_render_jinja():
129129
def test_prompt_basic():
130130
@prompts.template
131131
def test_tpl(variable):
132-
"""{{variable}} test"""
132+
return """{{variable}} test"""
133133

134-
assert test_tpl.template == "{{variable}} test"
135134
assert list(test_tpl.signature.parameters.keys()) == ["variable"]
136135

137136
with pytest.raises(TypeError):
@@ -145,7 +144,7 @@ def test_tpl(variable):
145144

146145
@prompts.template
147146
def test_single_quote_tpl(variable):
148-
"${variable} test"
147+
return "{{variable}} test"
149148

150149
p = test_tpl("test")
151150
assert p == "test test"
@@ -154,9 +153,8 @@ def test_single_quote_tpl(variable):
154153
def test_prompt_kwargs():
155154
@prompts.template
156155
def test_kwarg_tpl(var, other_var="other"):
157-
"""{{var}} and {{other_var}}"""
156+
return """{{var}} and {{other_var}}"""
158157

159-
assert test_kwarg_tpl.template == "{{var}} and {{other_var}}"
160158
assert list(test_kwarg_tpl.signature.parameters.keys()) == ["var", "other_var"]
161159

162160
p = test_kwarg_tpl("test")
@@ -169,30 +167,16 @@ def test_kwarg_tpl(var, other_var="other"):
169167
assert p == "test and test"
170168

171169

172-
def test_no_prompt():
173-
with pytest.raises(TypeError, match="template"):
174-
175-
@prompts.template
176-
def test_empty(variable):
177-
pass
178-
179-
with pytest.raises(TypeError, match="template"):
180-
181-
@prompts.template
182-
def test_only_code(variable):
183-
return variable
184-
185-
186170
@pytest.mark.filterwarnings("ignore: The model")
187171
def test_dispatch():
188172

189173
@prompts.template
190174
def simple_prompt(query: str):
191-
"""{{ query }}"""
175+
return """{{ query }}"""
192176

193177
@simple_prompt.register("provider/name")
194178
def simple_prompt_name(query: str):
195-
"""name: {{ query }}"""
179+
return """name: {{ query }}"""
196180

197181
assert list(simple_prompt.registry.keys()) == ["provider/name"]
198182
assert callable(simple_prompt)
@@ -214,7 +198,7 @@ def test_special_tokens():
214198

215199
@prompts.template
216200
def simple_prompt(query: str):
217-
"""{{ bos + query + eos }}"""
201+
return """{{ bos + query + eos }}"""
218202

219203
assert simple_prompt("test") == "test"
220204
assert simple_prompt["openai-community/gpt2"]("test") == "test<|endoftext|>"
@@ -225,7 +209,7 @@ def test_warn():
225209

226210
@prompts.template
227211
def simple_prompt():
228-
"""test"""
212+
return """test"""
229213

230214
with pytest.warns(UserWarning, match="not present in the special token"):
231215
simple_prompt["non-existent-model"]()

0 commit comments

Comments
 (0)