Skip to content

Commit f387e03

Browse files
hynky1999clefourrier
authored andcommitted
Adds template for translation tasks (#391)
* implement translation prompt * add small comment about translation prompt * change formatting to reformat language dependent parts --------- Co-authored-by: Clémentine Fourrier <22726840+clefourrier@users.noreply.github.com>
1 parent 51c980d commit f387e03

File tree

3 files changed

+286
-3
lines changed

3 files changed

+286
-3
lines changed

src/lighteval/tasks/templates/continuation.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ def get_continuation_prompt_function(
8888
language: Language,
8989
adapter: Callable[[dict], ContinuationInput | None] | ContinuationDictAdapter,
9090
formulation: Formulation = MCFFormulation(),
91+
fix_formatting: bool = True,
9192
):
9293
"""
9394
Create a templated prompt function for a Continuation task.
@@ -120,6 +121,7 @@ def get_continuation_prompt_function(
120121
adapter (Callable[[dict], ContinuationInput] | ContinuationDictAdapter): Either a function that takes a dataset row and returns a ContinuationInput, or a dictionary with keys corresponding to the field names in the dataset row.
121122
Note: Both ContinuationDictAdapter and ContinuationInput are TypeDicts, this means that the caller provides dictionary and doesn't initialize any class!
122123
formulation (Formulation, optional): The formulation (MCF/Hybrid/CF) to use for the task. Defaults to MCFFormulation().
124+
fix_formatting (bool, optional): Whether to fix the formatting of the text by capitalizing and fixing punctuation based on language. If False, the text will be used as-is. Defaults to True.
123125
Returns:
124126
Callable: A function that generates Continuation prompt based on the given parameters.
125127
"""
@@ -134,12 +136,17 @@ def prepare_prompt(line: dict):
134136
instruction_val = cont_input.get("instruction")
135137
instruction = f"{instruction_val}\n" if instruction_val else ""
136138

137-
context = f"{capitalize(fix_ending_punct(cont_input['context'], translation_literals))}"
138-
continuations = cont_input["continuations"]
139+
context = (
140+
f"{capitalize(fix_ending_punct(cont_input['context'], translation_literals))}"
141+
if fix_formatting
142+
else cont_input["context"]
143+
)
139144

140145
continuations = [
141146
fix_capitalization(context, fix_ending_punct(continuation, translation_literals), translation_literals)
142-
for continuation in continuations
147+
if fix_formatting
148+
else continuation
149+
for continuation in cont_input["continuations"]
143150
]
144151

145152
return cont_input, instruction, context, continuations
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
# MIT License
2+
3+
# Copyright (c) 2024 The HuggingFace Team
4+
5+
# Permission is hereby granted, free of charge, to any person obtaining a copy
6+
# of this software and associated documentation files (the "Software"), to deal
7+
# in the Software without restriction, including without limitation the rights
8+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
# copies of the Software, and to permit persons to whom the Software is
10+
# furnished to do so, subject to the following conditions:
11+
12+
# The above copyright notice and this permission notice shall be included in all
13+
# copies or substantial portions of the Software.
14+
15+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
# SOFTWARE.
22+
23+
from typing import Callable
24+
25+
from langcodes import standardize_tag
26+
from typing_extensions import NotRequired, TypedDict
27+
28+
from lighteval.tasks.templates.continuation import get_continuation_prompt_function
29+
from lighteval.tasks.templates.multichoice import create_adapter_from_dict
30+
from lighteval.tasks.templates.utils.formatting_utils import capitalize, fix_ending_punct
31+
from lighteval.tasks.templates.utils.formulation import Formulation, MCFFormulation
32+
from lighteval.tasks.templates.utils.translation_literals import TRANSLATION_LITERALS
33+
from lighteval.utils.language import Language
34+
from lighteval.utils.utils import as_list
35+
36+
37+
# The template is kept mostly language-independent on purpose: it is not clear whether
# the prompt should be written in the source language or in the target language.
# It is also the best template based on https://arxiv.org/pdf/2301.07069.
TRANSLATION_CONTEXT = (
    "{source_label}{colon}{sentence_space}{source_text}"
    "{sentence_space}{target_label}{colon}"
)
42+
43+
44+
# Defined for type hinting only
45+
class TranslationInput(TypedDict):
46+
"""
47+
Input for the Translation task.
48+
Args:
49+
source_text: The source text to be translated
50+
target_text: The target text to be translated
51+
instruction (optional): The instruction of the Translation task (e.g. Translate the following text to Turkish)
52+
"""
53+
54+
source_text: str
55+
target_text: str | list[str]
56+
gold_idx: NotRequired[int | list[int]]
57+
instruction: NotRequired[str]
58+
59+
60+
class TranslationAdapter(TypedDict):
61+
"""
62+
Adapter for mapping from the dataset row into the TranslationInput format.
63+
Args:
64+
source_text: Column name in the row that contains the source text to be translated
65+
target_text: Column name in the row that contains the target text to be translated
66+
instruction (optional): Column name in the row that contains the instruction of the task (e.g. Translate the following text to Turkish)
67+
"""
68+
69+
source_text: str
70+
target_text: str
71+
gold_idx: NotRequired[int | list[int]]
72+
instruction: NotRequired[str]
73+
74+
75+
def get_translation_prompt_function(
    source_language: Language,
    target_language: Language,
    adapter: Callable[[dict], TranslationInput | None] | TranslationAdapter,
    formulation: Formulation = MCFFormulation(),
):
    """
    Build a templated prompt function for a Translation task.

    The source text is placed into a "<SOURCE LABEL>: <source text> <TARGET LABEL>:" context
    and the target text(s) become the continuations. The prompt itself is then produced by
    the continuation template for the requested formulation.

    Example tasks:
    - WMT2016
    - WMT2017

    Format:
    *CF*
    EN: How are you? TR: | Nasılsın?

    *Hybrid*
    EN: How are you? TR:
    A. Nasılsın?
    B. Jak se máš?
    Answer: | Nasılsın?/Jak se máš?

    *MCF*
    EN: How are you? TR:
    A. Nasılsın?
    B. Jak se máš?
    Answer: | A/B

    Args:
        source_language (Language): Language of the source text. Selects the translation literals used to format the source text and the upper-cased source label.
        target_language (Language): Language of the target text(s). Selects the translation literals used to format the targets and the upper-cased target label.
        adapter (Callable[[dict], TranslationInput] | TranslationAdapter): Either a function that takes a dataset row and returns a TranslationInput, or a dictionary with keys corresponding to the field names in the dataset row.
            Note: Both TranslationAdapter and TranslationInput are TypedDicts, this means that the caller provides dictionary and doesn't initialize any class!
        formulation (Formulation, optional): The formulation to use for the task. Defaults to MCFFormulation().
    Returns:
        Callable: A function that generates Translation prompts based on the given parameters.
    """
    row_to_input = create_adapter_from_dict(adapter)

    # The continuation template is always created for Language.ENGLISH and with
    # fix_formatting=False: source and targets are formatted below with the literals
    # of their own language, so the template is asked to use the text as-is.
    build_continuation_prompt = get_continuation_prompt_function(
        Language.ENGLISH,
        {"context": "context", "continuations": "continuations", "gold_idx": "gold_idx"},
        formulation,
        fix_formatting=False,
    )

    source_literals = TRANSLATION_LITERALS[source_language]
    target_literals = TRANSLATION_LITERALS[target_language]

    # Labels are the upper-cased standardized language tags (e.g. "EN", "TR").
    source_label = standardize_tag(source_language.value).upper()
    target_label = standardize_tag(target_language.value).upper()

    def _format_text(text, literals):
        # Fix the ending punctuation using the given literals, then capitalize.
        return capitalize(fix_ending_punct(text, literals))

    def translation_prompt(
        line: dict,
        task_name: str,
    ):
        translation_input = row_to_input(line)
        if translation_input is None:
            # The adapter produced no input for this row.
            return None

        context = TRANSLATION_CONTEXT.format(
            source_label=source_label,
            source_text=_format_text(translation_input["source_text"], source_literals),
            target_label=target_label,
            colon=":",
            sentence_space=" ",
        )

        # target_text may be a single string or a list of strings.
        continuations = [
            _format_text(target, target_literals) for target in as_list(translation_input["target_text"])
        ]

        continuation_row = {
            "instruction": translation_input.get("instruction", ""),
            "context": context,
            "continuations": continuations,
            # Without an explicit gold_idx every continuation counts as gold.
            "gold_idx": translation_input.get("gold_idx", list(range(len(continuations)))),
        }
        return continuation_prompt_row(continuation_row, task_name)

    def continuation_prompt_row(row, task_name):
        # Thin indirection so the final call reads as a single step.
        return build_continuation_prompt(row, task_name)

    return translation_prompt
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
# MIT License
2+
3+
# Copyright (c) 2024 The HuggingFace Team
4+
5+
# Permission is hereby granted, free of charge, to any person obtaining a copy
6+
# of this software and associated documentation files (the "Software"), to deal
7+
# in the Software without restriction, including without limitation the rights
8+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
# copies of the Software, and to permit persons to whom the Software is
10+
# furnished to do so, subject to the following conditions:
11+
12+
# The above copyright notice and this permission notice shall be included in all
13+
# copies or substantial portions of the Software.
14+
15+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
# SOFTWARE.
22+
23+
24+
from lighteval.tasks.templates.translation import get_translation_prompt_function
25+
from lighteval.tasks.templates.utils.formulation import CFFormulation, MCFFormulation
26+
from lighteval.utils.language import Language
27+
28+
29+
def test_translation_prompt_cf():
    """
    Checks the prompt produced for the CF formulation with a single target translation.
    """
    row = {
        "source_text": "Ahoj, jak se máš?",
        "target_text": "Bonjour, comment allez-vous?",
    }

    def to_translation_input(sample):
        return {
            "source_text": sample["source_text"],
            "target_text": sample["target_text"],
        }

    build_prompt = get_translation_prompt_function(
        source_language=Language.CZECH,
        target_language=Language.FRENCH,
        adapter=to_translation_input,
        formulation=CFFormulation(),
    )

    doc = build_prompt(row, "test_task")
    assert doc is not None

    # The adapter supplies no gold_idx, so the only target is expected to be gold.
    assert doc.query == "CS: Ahoj, jak se máš? FR:"
    assert doc.unconditioned_query == ""
    assert doc.choices == [" Bonjour, comment allez-vous?"]
    assert doc.gold_index == [0]
56+
57+
def test_translation_prompt_mcf():
    """
    Checks the prompt produced for the MCF formulation with two candidate translations.
    """
    row = {
        "source_text": "Ahoj, jak se máš?",
        "target_text": ["Bonjour, comment allez-vous?", "Ciao, come stai?"],
    }

    def to_translation_input(sample):
        return {
            "source_text": sample["source_text"],
            "target_text": sample["target_text"],
            "gold_idx": 0,
        }

    build_prompt = get_translation_prompt_function(
        source_language=Language.CZECH,
        target_language=Language.FRENCH,
        adapter=to_translation_input,
        formulation=MCFFormulation(),
    )

    doc = build_prompt(row, "test_task")
    assert doc is not None

    # NOTE(review): leading whitespace inside this literal could not be recovered from
    # the whitespace-stripped source; confirm the option lines against the MCF template.
    expected_query = """\
CS: Ahoj, jak se máš? FR:
A. Bonjour, comment allez-vous?
B. Ciao, come stai?
Answer:\
"""
    assert doc.query == expected_query
    assert doc.unconditioned_query == "Answer:"
    assert doc.choices == [" A", " B"]
    assert doc.gold_index == [0]
92+
93+
94+
def test_translation_prompt_cf_formatting():
    """
    Checks the CF formulation prompt when source and target use different language literals.
    """
    row = {
        "source_text": "How are you?",
        "target_text": ["你好吗?"],
    }

    def to_translation_input(sample):
        translation_input = {key: sample[key] for key in ("source_text", "target_text")}
        translation_input["gold_idx"] = 0
        return translation_input

    build_prompt = get_translation_prompt_function(
        source_language=Language.ENGLISH,
        target_language=Language.CHINESE,
        adapter=to_translation_input,
        formulation=CFFormulation(),
    )

    doc = build_prompt(row, "test_task")
    assert doc is not None

    # NOTE(review): both inputs are already capitalized and punctuated, so this only
    # exercises formatting if the Chinese literals rewrite the trailing "?" (e.g. to a
    # full-width question mark). In that case the expected choice below would need that
    # character — confirm against the Chinese translation literals.
    assert doc.query == "EN: How are you? ZH:"
    assert doc.unconditioned_query == ""
    assert doc.choices == [" 你好吗?"]
    assert doc.gold_index == [0]

0 commit comments

Comments
 (0)