Skip to content

Commit c06b131

Browse files
feat: changed summary to new prompt (#1469)
Co-authored-by: Shahules786 <[email protected]>
1 parent a4b1912 commit c06b131

File tree

1 file changed: +126 additions, -172 deletions

src/ragas/metrics/_summarization.py

Lines changed: 126 additions & 172 deletions
Original file line numberDiff line numberDiff line change
@@ -5,140 +5,134 @@
55
from dataclasses import dataclass, field
66
from typing import Dict
77

8-
from langchain.pydantic_v1 import BaseModel
8+
from pydantic import BaseModel
99

1010
from ragas.dataset_schema import SingleTurnSample
11-
from ragas.llms.output_parser import RagasOutputParserOld, get_json_format_instructions
12-
from ragas.llms.prompt import Prompt
1311
from ragas.metrics.base import MetricType, MetricWithLLM, SingleTurnMetric
12+
from ragas.prompt import PydanticPrompt, StringIO
1413

1514
if t.TYPE_CHECKING:
1615
from langchain.callbacks.base import Callbacks
1716

1817
logger = logging.getLogger(__name__)
1918

2019

21-
class ExtractKeyphrasesResponse(BaseModel):
20+
# Structured result of keyphrase extraction: the list of phrases returned by
# the LLM. Documented with a comment rather than a docstring so the pydantic
# model's generated schema is left untouched.
class ExtractedKeyphrases(BaseModel):
    keyphrases: t.List[str]
2322

2423

25-
class GenerateQuestionsResponse(BaseModel):
24+
# Structured result of question generation: the list of questions returned by
# the LLM. Documented with a comment rather than a docstring so the pydantic
# model's generated schema is left untouched.
class QuestionsGenerated(BaseModel):
    questions: t.List[str]
2726

2827

29-
class GenerateAnswersResponse(BaseModel):
28+
# Structured result of answer generation: one string answer per question
# (the accompanying prompt instructs the LLM to use '1' or '0'). Documented
# with a comment rather than a docstring so the pydantic model's generated
# schema is left untouched.
class AnswersGenerated(BaseModel):
    answers: t.List[str]
3130

3231

33-
_output_instructions_question_generation = get_json_format_instructions(
34-
pydantic_object=GenerateQuestionsResponse # type: ignore
35-
)
36-
_output_instructions_answer_generation = get_json_format_instructions(
37-
pydantic_object=GenerateAnswersResponse # type: ignore
38-
)
39-
_output_instructions_keyphrase_extraction = get_json_format_instructions(
40-
pydantic_object=ExtractKeyphrasesResponse # type: ignore
41-
)
42-
_output_parser_question_generation = RagasOutputParserOld(
43-
pydantic_object=GenerateQuestionsResponse
44-
)
45-
_output_parser_answer_generation = RagasOutputParserOld(
46-
pydantic_object=GenerateAnswersResponse
47-
)
48-
_output_parser_keyphrase_extraction = RagasOutputParserOld(
49-
pydantic_object=ExtractKeyphrasesResponse
50-
)
51-
52-
53-
TEXT_EXTRACT_KEYPHRASES = Prompt(
54-
name="text_extract_keyphrases",
55-
instruction="Extract the keyphrases essential for summarizing the text.",
56-
output_format_instruction=_output_instructions_keyphrase_extraction,
57-
input_keys=["text"],
58-
output_key="keyphrases",
59-
output_type="json",
60-
examples=[
61-
{
62-
"text": """JPMorgan Chase & Co. is an American multinational finance company headquartered in New York City. It is the largest bank in the United States and the world's largest by market capitalization as of 2023. Founded in 1799, it is a major provider of investment banking services, with US$3.9 trillion in total assets, and ranked #1 in the Forbes Global 2000 ranking in 2023.""",
63-
"keyphrases": [
64-
"JPMorgan Chase & Co.",
65-
"American multinational finance company",
66-
"headquartered in New York City",
67-
"largest bank in the United States",
68-
"world's largest bank by market capitalization",
69-
"founded in 1799",
70-
"major provider of investment banking services",
71-
"US$3.9 trillion in total assets",
72-
"ranked #1 in Forbes Global 2000 ranking",
73-
],
74-
}
75-
],
76-
)
77-
78-
79-
TEXT_GENERATE_QUESTIONS = Prompt(
80-
name="text_generate_questions",
81-
instruction="Based on the given text and keyphrases, generate closed-ended questions that can be answered with '1' if the question can be answered using the text, or '0' if it cannot. The questions should ALWAYS result in a '1' based on the given text.",
82-
output_format_instruction=_output_instructions_question_generation,
83-
input_keys=["text", "keyphrases"],
84-
output_key="questions",
85-
output_type="json",
86-
examples=[
87-
{
88-
"text": """JPMorgan Chase & Co. is an American multinational finance company headquartered in New York City. It is the largest bank in the United States and the world's largest by market capitalization as of 2023. Founded in 1799, it is a major provider of investment banking services, with US$3.9 trillion in total assets, and ranked #1 in the Forbes Global 2000 ranking in 2023.""",
89-
"keyphrases": [
90-
"JPMorgan Chase & Co.",
91-
"American multinational finance company",
92-
"headquartered in New York City",
93-
"largest bank in the United States",
94-
"world's largest bank by market capitalization",
95-
"founded in 1799",
96-
"major provider of investment banking services",
97-
"US$3.9 trillion in total assets",
98-
"ranked #1 in Forbes Global 2000 ranking",
99-
],
100-
"questions": [
101-
"Is JPMorgan Chase & Co. an American multinational finance company?",
102-
"Is JPMorgan Chase & Co. headquartered in New York City?",
103-
"Is JPMorgan Chase & Co. the largest bank in the United States?",
104-
"Is JPMorgan Chase & Co. the world's largest bank by market capitalization as of 2023?",
105-
"Was JPMorgan Chase & Co. founded in 1799?",
106-
"Is JPMorgan Chase & Co. a major provider of investment banking services?",
107-
"Does JPMorgan Chase & Co. have US$3.9 trillion in total assets?",
108-
"Was JPMorgan Chase & Co. ranked #1 in the Forbes Global 2000 ranking in 2023?",
109-
],
110-
}
111-
],
112-
)
113-
114-
115-
TEXT_GENERATE_ANSWERS = Prompt(
116-
name="text_generate_answers",
117-
instruction="Based on the list of close-ended '1' or '0' questions, generate a JSON with key 'answers', which is a list of strings that determines whether the provided summary contains sufficient information to answer EACH question. Answers should STRICTLY be either '1' or '0'. Answer '0' if the provided summary does not contain enough information to answer the question and answer '1' if the provided summary can answer the question.",
118-
output_format_instruction=_output_instructions_answer_generation,
119-
input_keys=["summary", "questions"],
120-
output_key="answers",
121-
output_type="json",
122-
examples=[
123-
{
124-
"summary": """JPMorgan Chase & Co., headquartered in New York City, is the largest bank in the US and the world's largest by market capitalization as of 2023. Founded in 1799, it offers extensive investment, private, asset management, and retail banking services, and has $3.9 trillion in assets, making it the fifth-largest bank globally. It operates the world's largest investment bank by revenue and was ranked #1 in the 2023 Forbes Global 2000.""",
125-
"questions": [
126-
"Is JPMorgan Chase & Co. an American multinational finance company?",
127-
"Is JPMorgan Chase & Co. headquartered in New York City?",
128-
"Is JPMorgan Chase & Co. the largest bank in the United States?",
129-
"Is JPMorgan Chase & Co. the world's largest bank by market capitalization as of 2023?",
130-
"Is JPMorgan Chase & Co. considered systemically important by the Financial Stability Board?",
131-
"Was JPMorgan Chase & Co. founded in 1799 as the Chase Manhattan Company?",
132-
"Is JPMorgan Chase & Co. a major provider of investment banking services?",
133-
"Is JPMorgan Chase & Co. the fifth-largest bank in the world by assets?",
134-
"Does JPMorgan Chase & Co. operate the largest investment bank by revenue?",
135-
"Was JPMorgan Chase & Co. ranked #1 in the Forbes Global 2000 ranking?",
136-
"Does JPMorgan Chase & Co. provide investment banking services?",
137-
],
138-
"answers": ["0", "1", "1", "1", "0", "0", "1", "1", "1", "1", "1"],
139-
}
140-
],
141-
)
32+
class ExtractKeyphrasePrompt(PydanticPrompt[StringIO, ExtractedKeyphrases]):
    """Prompt asking the LLM to extract entity-style keyphrases from a text.

    The input is a ``StringIO`` carrying the raw text; the output is an
    ``ExtractedKeyphrases`` holding the extracted phrases.
    """

    name: str = "extract_keyphrases"
    instruction: str = (
        "Extract keyphrases of type: Person, Organization, Location, "
        "Date/Time, Monetary Values, and Percentages."
    )
    input_model = StringIO
    output_model = ExtractedKeyphrases
    # One worked example, given as an (input, expected output) pair.
    examples: t.List[t.Tuple[StringIO, ExtractedKeyphrases]] = [
        (
            StringIO(
                text=(
                    "Apple Inc. is a technology company based in Cupertino, "
                    "California. Founded by Steve Jobs in 1976, it reached a "
                    "market capitalization of $3 trillion in 2023."
                )
            ),
            ExtractedKeyphrases(
                keyphrases=[
                    "Apple Inc.",
                    "Cupertino, California",
                    "Steve Jobs",
                    "1976",
                    "$3 trillion",
                    "2023",
                ]
            ),
        )
    ]
54+
55+
56+
# Input payload for the question-generation prompt: the source text together
# with a list of keyphrases to build questions around.
class GenerateQuestionsPromptInput(BaseModel):
    text: str
    keyphrases: t.List[str]
59+
60+
61+
class GenerateQuestionsPrompt(
    PydanticPrompt[GenerateQuestionsPromptInput, QuestionsGenerated]
):
    """Prompt asking the LLM for closed-ended questions about a text.

    The input pairs the text with its keyphrases
    (``GenerateQuestionsPromptInput``); the output is a
    ``QuestionsGenerated`` holding the generated questions.
    """

    name: str = "generate_questions"
    instruction: str = (
        "Based on the given text and keyphrases, generate closed-ended "
        "questions that can be answered with '1' if the question can be "
        "answered using the text, or '0' if it cannot. The questions should "
        "ALWAYS result in a '1' based on the given text."
    )
    input_model = GenerateQuestionsPromptInput
    output_model = QuestionsGenerated
    # One worked example, given as an (input, expected output) pair.
    examples: t.List[t.Tuple[GenerateQuestionsPromptInput, QuestionsGenerated]] = [
        (
            GenerateQuestionsPromptInput(
                text=(
                    "Apple Inc. is a technology company based in Cupertino, "
                    "California. Founded by Steve Jobs in 1976, it reached a "
                    "market capitalization of $3 trillion in 2023."
                ),
                keyphrases=[
                    "Apple Inc.",
                    "Cupertino, California",
                    "Steve Jobs",
                    "1976",
                    "$3 trillion",
                    "2023",
                ],
            ),
            QuestionsGenerated(
                questions=[
                    "Is Apple Inc. a technology company?",
                    "Is Apple Inc. based in Cupertino, California?",
                    "Was Apple Inc. founded by Steve Jobs?",
                    "Was Apple Inc. founded in 1976?",
                    "Did Apple Inc. reach a market capitalization of $3 trillion?",
                    "Did Apple Inc. reach a market capitalization of $3 trillion in 2023?",
                ]
            ),
        )
    ]
93+
94+
95+
# Input payload for the answer-generation prompt: a summary and the list of
# questions to be checked against it.
class SummaryAndQuestions(BaseModel):
    summary: str
    questions: t.List[str]
98+
99+
100+
class GenerateAnswersPrompt(PydanticPrompt[SummaryAndQuestions, AnswersGenerated]):
    """Prompt asking the LLM whether a summary can answer each question.

    The input is a ``SummaryAndQuestions``; the output is an
    ``AnswersGenerated`` whose ``answers`` list is expected to hold one
    '1'/'0' string per question, in question order.
    """

    name: str = "generate_answers"
    instruction: str = (
        "Based on the list of close-ended '1' or '0' questions, generate a "
        "JSON with key 'answers', which is a list of strings that determines "
        "whether the provided summary contains sufficient information to "
        "answer EACH question. Answers should STRICTLY be either '1' or '0'. "
        "Answer '0' if the provided summary does not contain enough "
        "information to answer the question and answer '1' if the provided "
        "summary can answer the question."
    )
    input_model = SummaryAndQuestions
    output_model = AnswersGenerated
    # One worked example, given as an (input, expected output) pair. The
    # answers line up positionally with the questions (nine of each).
    examples: t.List[t.Tuple[SummaryAndQuestions, AnswersGenerated]] = [
        (
            SummaryAndQuestions(
                summary=(
                    "Apple Inc. is a technology company based in Cupertino, "
                    "California. Founded by Steve Jobs in 1976, it reached a "
                    "market capitalization of $3 trillion in 2023."
                ),
                questions=[
                    "Is Apple Inc. a technology company?",
                    "Is Apple Inc. based in Cupertino, California?",
                    "Was Apple Inc. founded by Steve Jobs?",
                    "Was Apple Inc. founded in 1976?",
                    "Did Apple Inc. reach a market capitalization of $3 trillion?",
                    "Did Apple Inc. reach a market capitalization of $3 trillion in 2023?",
                    "Is Apple Inc. a major software company?",
                    "Is Apple Inc. known for the iPhone?",
                    "Was Steve Jobs the co-founder of Apple Inc.?",
                ],
            ),
            AnswersGenerated(
                answers=["1", "1", "1", "1", "1", "1", "0", "0", "1"]
            ),
        )
    ]
142136

143137

144138
@dataclass
@@ -155,14 +149,14 @@ class SummarizationScore(MetricWithLLM, SingleTurnMetric):
155149
}
156150
)
157151
coeff: float = 0.5
158-
question_generation_prompt: Prompt = field(
159-
default_factory=lambda: TEXT_GENERATE_QUESTIONS
152+
question_generation_prompt: PydanticPrompt = field(
153+
default_factory=GenerateQuestionsPrompt
160154
)
161-
answer_generation_prompt: Prompt = field(
162-
default_factory=lambda: TEXT_GENERATE_ANSWERS
155+
answer_generation_prompt: PydanticPrompt = field(
156+
default_factory=GenerateAnswersPrompt
163157
)
164-
extract_keyphrases_prompt: Prompt = field(
165-
default_factory=lambda: TEXT_EXTRACT_KEYPHRASES
158+
extract_keyphrases_prompt: PydanticPrompt = field(
159+
default_factory=ExtractKeyphrasePrompt
166160
)
167161

168162
async def _single_turn_ascore(
@@ -201,17 +195,11 @@ def _compute_conciseness_score(self, text, summary) -> float:
201195

202196
async def _extract_keyphrases(self, text: str, callbacks: Callbacks) -> t.List[str]:
203197
assert self.llm is not None, "LLM is not initialized"
204-
p_value = self.extract_keyphrases_prompt.format(text=text)
205-
result = await self.llm.generate(
206-
prompt=p_value,
207-
callbacks=callbacks,
208-
)
209-
result_text = result.generations[0][0].text
210-
response = await _output_parser_keyphrase_extraction.aparse(
211-
result_text, p_value, self.llm, self.max_retries
212-
)
213198

214-
if not response or not response.keyphrases:
199+
response: ExtractedKeyphrases = await self.extract_keyphrases_prompt.generate(
200+
data=StringIO(text=text), llm=self.llm, callbacks=callbacks
201+
)
202+
if not response:
215203
logging.error("No keyphrases generated, unable to calculate the score.")
216204
return []
217205

@@ -221,20 +209,12 @@ async def _get_questions(
221209
self, text: str, keyphrases: list[str], callbacks: Callbacks
222210
) -> t.List[str]:
223211
assert self.llm is not None, "LLM is not initialized"
224-
p_value = self.question_generation_prompt.format(
225-
text=text, keyphrases=keyphrases
226-
)
227-
result = await self.llm.generate(
228-
prompt=p_value,
212+
response: QuestionsGenerated = await self.question_generation_prompt.generate(
213+
data=GenerateQuestionsPromptInput(text=text, keyphrases=keyphrases),
214+
llm=self.llm,
229215
callbacks=callbacks,
230216
)
231-
232-
result_text = result.generations[0][0].text
233-
response = await _output_parser_question_generation.aparse(
234-
result_text, p_value, self.llm, self.max_retries
235-
)
236-
237-
if not response or not response.questions:
217+
if not response:
238218
logging.error("No questions generated, unable to calculate the score.")
239219
return []
240220

@@ -244,38 +224,12 @@ async def _get_answers(
244224
self, questions: t.List[str], summary: str, callbacks: Callbacks
245225
) -> t.List[str]:
246226
assert self.llm is not None, "LLM is not initialized"
247-
p_value = self.answer_generation_prompt.format(
248-
questions=questions, summary=summary
249-
)
250-
result = await self.llm.generate(
251-
prompt=p_value,
227+
response: AnswersGenerated = await self.answer_generation_prompt.generate(
228+
data=SummaryAndQuestions(questions=questions, summary=summary),
229+
llm=self.llm,
252230
callbacks=callbacks,
253231
)
254-
255-
result_text = result.generations[0][0].text
256-
response = await _output_parser_answer_generation.aparse(
257-
result_text, p_value, self.llm, self.max_retries
258-
)
259-
260-
if not response or not response.answers:
261-
logger.error("No answers generated, unable to calculate the score.")
262-
return []
263-
264232
return response.answers
265233

266-
def adapt(self, language: str, cache_dir: str | None = None) -> None:
267-
assert self.llm is not None, "set LLM before use"
268-
269-
logger.info(f"Adapting summarization to {language}")
270-
self.question_generation_prompt = self.question_generation_prompt.adapt(
271-
language, self.llm, cache_dir
272-
)
273-
self.answer_generation_prompt = self.answer_generation_prompt.adapt(
274-
language, self.llm, cache_dir
275-
)
276-
self.answer_generation_prompt = self.answer_generation_prompt.adapt(
277-
language, self.llm, cache_dir
278-
)
279-
280234

281235
summarization_score = SummarizationScore()

0 commit comments

Comments
 (0)