Skip to content

Commit 679401d

Browse files
authored
Added tests for prompt recipe (#45)
* added test for jinja recipe * prettify, fixed warning
1 parent e4ef335 commit 679401d

File tree

6 files changed

+191
-262
lines changed

6 files changed

+191
-262
lines changed

examples/zero_shot_prompting.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ def __init__(
3838
self.recipe = smashed.recipes.JinjaRecipe(
3939
tokenizer=self.tokenizer,
4040
jinja_template=template,
41-
max_source_content_length=max_source_content_length,
42-
max_target_content_length=max_target_content_length,
41+
max_source_length_per_shot=max_source_content_length,
42+
max_target_length_per_shot=max_target_content_length,
4343
) >> smashed.recipes.CollatorRecipe(
4444
tokenizer=self.tokenizer,
4545
device=device,

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "smashed"
3-
version = "0.15.3"
3+
version = "0.15.4"
44
description = "Sequential MAppers for Sequences of HEterogeneous Dictionaries is a set of Python interfaces designed to apply transformations to samples in datasets, which are often implemented as sequences of dictionaries."
55
authors = [
66
{name = "Allen Institute for Artificial Intelligence", email = "contact@allenai.org" },

src/smashed/mappers/prompting.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,6 @@ def _find_truncated_lens_longest(
245245
redistributed_extra_len = cls._find_truncated_lens_uniform(
246246
lens=longer_than_average,
247247
max_len=extra_len_to_redistribute,
248-
# max_length=max_len,
249248
)
250249

251250
# we figure out new lengths by adding the redistributed extra length

src/smashed/recipes/promptsource.py

Lines changed: 46 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ def __init__(
1717
tokenizer: PreTrainedTokenizerBase,
1818
jinja_template: str,
1919
num_shots: int = 0,
20-
max_source_content_length: Optional[int] = None,
21-
max_target_content_length: Optional[int] = None,
20+
max_source_length_per_shot: Optional[int] = None,
21+
max_target_length_per_shot: Optional[int] = None,
2222
truncation_strategy: Literal["longest", "uniform"] = "longest",
2323
use_words: bool = True,
2424
source_fields: Optional[Sequence[str]] = None,
@@ -35,14 +35,15 @@ def __init__(
3535
the source and target; we use promptsource to parse the
3636
template and extract the source and target fields; please
3737
see the promptsource documentation for more details.
38-
max_source_content_length (Optional[int], optional): the maximum
39-
length of the source content (i.e., the content that is given
40-
as input to the model). If not provided, no truncation will
41-
be performed. Defaults to None.
38+
max_source_length_per_shot (Optional[int], optional): the maximum
39+
length of all the fields that are part of the source in a
40+
prompting shot. If not provided, no truncation will be
41+
performed. Defaults to None.
4242
max_target_length_per_shot (Optional[int], optional): the maximum
43-
length of the target content (i.e., the content that is
44-
expected as output from the model). If not provided, no
45-
truncation will be performed. Defaults to None.
43+
length of all the fields that are part of the target in a
44+
prompting shot (that is, the text the model is asked to
45+
generate). If not provided, no truncation will be performed.
46+
Defaults to None.
4647
truncation_strategy ("longest" or "uniform", optional): how to
4748
perform truncation if the source or target content is longer
4849
than the maximum length. If "longest", the longest fields
@@ -124,17 +125,42 @@ def __init__(
124125
# if we don't use words, we just use the length of the prompt
125126
# in characters.
126127
length_src_prompt = len(source_text)
127-
length_tgt_prompt = len(target_text)
128+
# for target, we actually take the max in case there are multiple,
129+
# and 0 if there are none.
130+
length_tgt_prompt = max([len(t) for t in target_text] or [0])
131+
132+
# one liner to round to ceil. avoid import of math.ceil
133+
def ceil(x):
134+
return int(x + (1 if x % 1 else 0)) # noqa: E731
128135

129-
if max_source_content_length is not None:
136+
if max_source_length_per_shot is not None:
130137
# in case a max length for the source is provided, we need to
131-
# truncate; first, we decrease the max length by the length of
132-
# prompt text.
133-
max_source_content_length -= length_src_prompt
138+
# truncate. The total max_length for source data in each shot
139+
# needs to be reduced by (a) the length of the target prompt
140+
# text when doing few-shot, and (b) the length of text of
141+
# the prompt.
142+
#
143+
# For both (a) and (b), we need to distribute the length by
144+
# the number of shots:
145+
# (a): recall that each prompt will contain n shots + the
146+
# prompt for the sequence we care about. So when doing
147+
# n shot, we are adding n target sequences, but are
148+
# truncating n + 1 target sequences. Therefore, we multiply
149+
# target length by n but divide by (n + 1)
150+
# (b): the text that is part of the prompt but is not variables
151+
# (e.g., instructions) must be divided over n + 1 sources.
152+
actual_source_context_length = (
153+
max_source_length_per_shot
154+
- ceil(
155+
(max_target_length_per_shot or 0)
156+
* (num_shots / (num_shots + 1))
157+
)
158+
- ceil(length_src_prompt / (num_shots + 1))
159+
)
134160

135161
# we raise if the max length is less than one after accounting
136162
# for the length of the prompt text.
137-
if max_source_content_length < 1:
163+
if actual_source_context_length < 1:
138164
raise ValueError(
139165
f"max_source_length_per_shot must be at least equal to "
140166
f"the length of the source prompt ({length_src_prompt})!"
@@ -144,17 +170,17 @@ def __init__(
144170
self.chain(
145171
TruncateMultipleFieldsMapper(
146172
fields_to_truncate=source_fields,
147-
max_length=max_source_content_length,
173+
max_length=actual_source_context_length,
148174
strategy=truncation_strategy,
149175
)
150176
)
151177

152-
if len(target_text) > 0 and max_target_content_length:
178+
if len(target_text) > 0 and max_target_length_per_shot:
153179
# we operate here in the same way as for the source, but we
154180
# only do it if there is a target prompt.
155-
max_target_content_length -= length_tgt_prompt
181+
max_target_length_per_shot -= length_tgt_prompt
156182

157-
if max_target_content_length < 1:
183+
if max_target_length_per_shot < 1:
158184
raise ValueError(
159185
f"max_target_length_per_shot must be at least equal to "
160186
f"the length of the target prompt ({length_tgt_prompt})!"
@@ -163,7 +189,7 @@ def __init__(
163189
self.chain(
164190
TruncateMultipleFieldsMapper(
165191
fields_to_truncate=target_fields,
166-
max_length=max_target_content_length,
192+
max_length=max_target_length_per_shot,
167193
strategy=truncation_strategy,
168194
)
169195
)

tests/test_promptsource.py

Lines changed: 52 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,44 @@
11
import unittest
22

3-
from transformers.models.auto import AutoTokenizer
4-
53
from smashed.mappers.promptsource import (
64
FewShotJinjaMapper,
75
JinjaMapper,
86
PromptsourceMapper,
97
SingleTransformPromptsourceMixin,
108
)
11-
from smashed.recipes.promptsource import JinjaRecipe
9+
10+
FEW_SHOT_DATASET = [
11+
{
12+
"question": "Who is Bill Gates?",
13+
"answer": "Bill Gates is a billionaire.",
14+
},
15+
{
16+
"question": "who is john lennon?",
17+
"answer": "John Lennon was a musician.",
18+
},
19+
{
20+
"question": "who is john doe?",
21+
"answer": "John Doe is a fictional character.",
22+
},
23+
{
24+
"question": "who is goldie hawn?",
25+
"answer": "Goldie Hawn is an actress.",
26+
},
27+
{
28+
"question": "who is ru paul?",
29+
"answer": "Ru Paul is a drag queen.",
30+
},
31+
]
32+
33+
FEW_SHOT_PROMPT = (
34+
"{% for shot in __shots__ %}"
35+
"Q: {{shot.question}}\n"
36+
"A: {{shot.answer}}\n"
37+
"\n"
38+
"{% endfor %}"
39+
"Q: {{question}}\n"
40+
"A: </s>|||{{answer}}"
41+
)
1242

1343

1444
class TestPromptsource(unittest.TestCase):
@@ -60,123 +90,46 @@ def test_dataset_prompt_source_mapper(self):
6090
mapped_dataset2 = mapper2.map(dataset, remove_columns=True)
6191
self.assertEqual(mapped_dataset, mapped_dataset2)
6292

63-
def test_promptsource_recipe(self):
64-
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
65-
66-
recipe = JinjaRecipe(
67-
tokenizer=AutoTokenizer.from_pretrained("bert-base-cased"),
68-
jinja_template="Q: {{question}}\nC: {{context}}\nA: |||{{answer}}",
69-
max_source_content_length=15,
70-
max_target_content_length=5,
71-
)
72-
dataset = [
73-
{
74-
"question": "What is the capital of France?",
75-
"context": "Paris is the capital of " + ("France " * 10),
76-
"answer": "Paris " * 10,
77-
}
78-
]
79-
80-
mapped_dataset, *_ = recipe.map(dataset)
81-
82-
self.assertEqual(
83-
tokenizer.decode(mapped_dataset["input_ids"]),
84-
(
85-
"Q : What is the capital of France? "
86-
"C : Paris is the capital of France "
87-
"A :"
88-
),
89-
)
90-
91-
self.assertEqual(
92-
tokenizer.decode(mapped_dataset["labels"]),
93-
"Paris Paris Paris Paris Paris",
94-
)
95-
96-
def _few_shot_data_prompt(self):
97-
dataset = [
98-
{
99-
"question": "Who is Bill Gates?",
100-
"answer": "Bill Gates is a billionaire.",
101-
},
102-
{
103-
"question": "who is john lennon?",
104-
"answer": "John Lennon was a musician.",
105-
},
106-
{
107-
"question": "who is john doe?",
108-
"answer": "John Doe is a fictional character.",
109-
},
110-
{
111-
"question": "who is goldie hawn?",
112-
"answer": "Goldie Hawn is an actress.",
113-
},
114-
{
115-
"question": "who is ru paul?",
116-
"answer": "Ru Paul is a drag queen.",
117-
},
118-
]
119-
jinja_prompt = (
120-
"{% for shot in __shots__ %}"
121-
"Q: {{shot.question}}\n"
122-
"A: {{shot.answer}}\n"
123-
"\n"
124-
"{% endfor %}"
125-
"Q: {{question}}\n"
126-
"A: </s>|||{{answer}}"
127-
)
128-
129-
return dataset, jinja_prompt
130-
13193
def test_fewshot_jinja(self):
94+
mapper = FewShotJinjaMapper(jinja=FEW_SHOT_PROMPT, num_shots=2)
13295

133-
dataset, jinja_prompt = self._few_shot_data_prompt()
134-
135-
mapper = FewShotJinjaMapper(jinja=jinja_prompt, num_shots=2)
136-
137-
mapped_dataset = mapper.map(dataset)
96+
mapped_dataset = mapper.map(FEW_SHOT_DATASET)
13897

13998
self.assertEqual(len(mapped_dataset), 1)
14099

141100
self.assertEqual(
142101
mapped_dataset[0]["source"],
143102
(
144-
"Q: Who is Bill Gates?\nA: Bill Gates is a billionaire.\n\n"
145-
"Q: who is john lennon?\nA: John Lennon was a musician.\n\n"
146-
"Q: who is john doe?\nA: </s>"
103+
f"Q: {FEW_SHOT_DATASET[0]['question']}\n"
104+
f"A: {FEW_SHOT_DATASET[0]['answer']}\n\n"
105+
f"Q: {FEW_SHOT_DATASET[1]['question']}\n"
106+
f"A: {FEW_SHOT_DATASET[1]['answer']}\n\n"
107+
f"Q: {FEW_SHOT_DATASET[2]['question']}\nA: </s>"
147108
),
148109
)
149110

150111
self.assertEqual(
151112
mapped_dataset[0]["target"],
152-
"John Doe is a fictional character.",
113+
FEW_SHOT_DATASET[2]["answer"],
153114
)
154115

155116
def test_few_shot_jinja_zero_shots(self):
156-
dataset, jinja_prompt = self._few_shot_data_prompt()
117+
mapper = FewShotJinjaMapper(jinja=FEW_SHOT_PROMPT, num_shots=0)
157118

158-
mapper = FewShotJinjaMapper(jinja=jinja_prompt, num_shots=0)
159-
160-
mapped_dataset = mapper.map(dataset)
119+
mapped_dataset = mapper.map(FEW_SHOT_DATASET)
161120

162121
self.assertEqual(len(mapped_dataset), 5)
163122

164-
self.assertEqual(
165-
mapped_dataset[0]["source"], "Q: Who is Bill Gates?\nA: </s>"
166-
)
167-
168-
self.assertEqual(
169-
mapped_dataset[0]["target"],
170-
"Bill Gates is a billionaire.",
171-
)
123+
for i in range(5):
124+
self.assertEqual(
125+
mapped_dataset[i]["source"],
126+
f"Q: {FEW_SHOT_DATASET[i]['question']}\nA: </s>",
127+
)
172128

173-
self.assertEqual(
174-
mapped_dataset[1]["source"], "Q: who is john lennon?\nA: </s>"
175-
)
176-
self.assertEqual(
177-
mapped_dataset[1]["target"],
178-
"John Lennon was a musician.",
179-
)
129+
self.assertEqual(
130+
mapped_dataset[i]["target"],
131+
FEW_SHOT_DATASET[i]["answer"],
132+
)
180133

181134
def test_few_shot_exception(self):
182135
with self.assertRaises(KeyError):

0 commit comments

Comments
 (0)