
Commit 29f70cf

fixes: handle long context extraction (#1680)
1 parent c729d08 commit 29f70cf

File tree

3 files changed: +67 −24 lines changed


docs/extra/components/choose_generator_llm.md

Lines changed: 1 addition & 0 deletions

````diff
@@ -16,6 +16,7 @@
 
 ```python
 from ragas.llms import LangchainLLMWrapper
+from ragas.embeddings import LangchainEmbeddingsWrapper
 from langchain_openai import ChatOpenAI
 from langchain_openai import OpenAIEmbeddings
 generator_llm = LangchainLLMWrapper(ChatOpenAI(model="gpt-4o"))
````
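The new import only pays off if the snippet also wraps an embedding model; a plausible continuation of the documented example (the generator_embeddings name is an assumption, not shown in this diff):

```python
from ragas.embeddings import LangchainEmbeddingsWrapper
from langchain_openai import OpenAIEmbeddings

# Assumed continuation of the docs snippet: wrap the embedding model the
# same way the generator LLM is wrapped above.
generator_embeddings = LangchainEmbeddingsWrapper(OpenAIEmbeddings())
```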

src/ragas/testset/transforms/base.py

Lines changed: 20 additions & 0 deletions

```diff
@@ -3,10 +3,15 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
 
+import tiktoken
+from tiktoken.core import Encoding
+
 from ragas.llms import BaseRagasLLM, llm_factory
 from ragas.prompt import PromptMixin
 from ragas.testset.graph import KnowledgeGraph, Node, Relationship
 
+DEFAULT_TOKENIZER = tiktoken.get_encoding("o200k_base")
+
 logger = logging.getLogger(__name__)
 
 
@@ -188,6 +193,21 @@ async def apply_extract(node: Node):
 class LLMBasedExtractor(Extractor, PromptMixin):
     llm: BaseRagasLLM = field(default_factory=llm_factory)
     merge_if_possible: bool = True
+    max_token_limit: int = 32000
+    tokenizer: Encoding = DEFAULT_TOKENIZER
+
+    def split_text_by_token_limit(self, text, max_token_limit):
+
+        # Tokenize the entire input string
+        tokens = self.tokenizer.encode(text)
+
+        # Split tokens into chunks of max_token_limit or less
+        chunks = []
+        for i in range(0, len(tokens), max_token_limit):
+            chunk_tokens = tokens[i : i + max_token_limit]
+            chunks.append(self.tokenizer.decode(chunk_tokens))
+
+        return chunks
 
 
 class Splitter(BaseGraphTransformation):
```
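The helper slices the encoded token stream into fixed-size windows, so every chunk except possibly the last is exactly max_token_limit tokens long before decoding. A minimal standalone sketch of the same slicing logic (the sample text and printed counts are illustrative assumptions, not part of the commit):

```python
import tiktoken

tokenizer = tiktoken.get_encoding("o200k_base")

def split_text_by_token_limit(text: str, max_token_limit: int) -> list[str]:
    # Encode once, then slice the token ids into fixed-size windows
    # and decode each window back to text.
    tokens = tokenizer.encode(text)
    return [
        tokenizer.decode(tokens[i : i + max_token_limit])
        for i in range(0, len(tokens), max_token_limit)
    ]

chunks = split_text_by_token_limit("word " * 50_000, max_token_limit=32_000)
print(len(chunks))  # likely 2: one ~32k-token chunk plus the remainder
```

Note that slicing on token boundaries means a chunk can end mid-sentence; the extractors changed below compensate by either querying only the first chunk or aggregating results across all chunks.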

src/ragas/testset/transforms/extractors/llm_based.py

Lines changed: 46 additions & 24 deletions

```diff
@@ -114,7 +114,9 @@ class HeadlinesExtractorPrompt(PydanticPrompt[TextWithExtractionLimit, Headlines
                     "Introduction",
                     "Main Concepts",
                     "Detailed Analysis",
+                    "Subsection: Specialized Techniques"
                     "Future Directions",
+                    "Conclusion",
                 ],
             ),
         ),
@@ -174,14 +176,15 @@ async def extract(self, node: Node) -> t.Tuple[str, t.Any]:
         node_text = node.get_property("page_content")
         if node_text is None:
             return self.property_name, None
-        result = await self.prompt.generate(self.llm, data=StringIO(text=node_text))
+        chunks = self.split_text_by_token_limit(node_text, self.max_token_limit)
+        result = await self.prompt.generate(self.llm, data=StringIO(text=chunks[0]))
         return self.property_name, result.text
 
 
 @dataclass
 class KeyphrasesExtractor(LLMBasedExtractor):
     """
-    Extracts top 5 keyphrases from the given text.
+    Extracts top keyphrases from the given text.
 
     Attributes
     ----------
@@ -199,10 +202,15 @@ async def extract(self, node: Node) -> t.Tuple[str, t.Any]:
         node_text = node.get_property("page_content")
         if node_text is None:
             return self.property_name, None
-        result = await self.prompt.generate(
-            self.llm, data=TextWithExtractionLimit(text=node_text, max_num=self.max_num)
-        )
-        return self.property_name, result.keyphrases
+        chunks = self.split_text_by_token_limit(node_text, self.max_token_limit)
+        keyphrases = []
+        for chunk in chunks:
+            result = await self.prompt.generate(
+                self.llm, data=TextWithExtractionLimit(text=chunk, max_num=self.max_num)
+            )
+            keyphrases.extend(result.keyphrases)
+        return self.property_name, keyphrases
+
 
 
 @dataclass
@@ -225,7 +233,8 @@ async def extract(self, node: Node) -> t.Tuple[str, t.Any]:
         node_text = node.get_property("page_content")
         if node_text is None:
             return self.property_name, None
-        result = await self.prompt.generate(self.llm, data=StringIO(text=node_text))
+        chunks = self.split_text_by_token_limit(node_text, self.max_token_limit)
+        result = await self.prompt.generate(self.llm, data=StringIO(text=chunks[0]))
         return self.property_name, result.text
 
 
@@ -250,12 +259,15 @@ async def extract(self, node: Node) -> t.Tuple[str, t.Any]:
         node_text = node.get_property("page_content")
         if node_text is None:
             return self.property_name, None
-        result = await self.prompt.generate(
-            self.llm, data=TextWithExtractionLimit(text=node_text, max_num=self.max_num)
-        )
-        if result is None:
-            return self.property_name, None
-        return self.property_name, result.headlines
+        chunks = self.split_text_by_token_limit(node_text, self.max_token_limit)
+        headlines = []
+        for chunk in chunks:
+            result = await self.prompt.generate(
+                self.llm, data=TextWithExtractionLimit(text=chunk, max_num=self.max_num)
+            )
+            if result:
+                headlines.extend(result.headlines)
+        return self.property_name, headlines
 
 
 @dataclass
@@ -279,11 +291,15 @@ async def extract(self, node: Node) -> t.Tuple[str, t.List[str]]:
         node_text = node.get_property("page_content")
         if node_text is None:
             return self.property_name, []
-        result = await self.prompt.generate(
-            self.llm,
-            data=TextWithExtractionLimit(text=node_text, max_num=self.max_num_entities),
-        )
-        return self.property_name, result.entities
+        chunks = self.split_text_by_token_limit(node_text, self.max_token_limit)
+        entities = []
+        for chunk in chunks:
+            result = await self.prompt.generate(
+                self.llm,
+                data=TextWithExtractionLimit(text=chunk, max_num=self.max_num_entities),
+            )
+            entities.extend(result.entities)
+        return self.property_name, entities
 
 
 class TopicDescription(BaseModel):
@@ -328,7 +344,8 @@ async def extract(self, node: Node) -> t.Tuple[str, t.Any]:
         node_text = node.get_property("page_content")
         if node_text is None:
             return self.property_name, None
-        result = await self.prompt.generate(self.llm, data=StringIO(text=node_text))
+        chunks = self.split_text_by_token_limit(node_text, self.max_token_limit)
+        result = await self.prompt.generate(self.llm, data=StringIO(text=chunks[0]))
         return self.property_name, result.description
 
 
@@ -383,8 +400,13 @@ async def extract(self, node: Node) -> t.Tuple[str, t.List[str]]:
         node_text = node.get_property("page_content")
         if node_text is None:
             return self.property_name, []
-        result = await self.prompt.generate(
-            self.llm,
-            data=TextWithExtractionLimit(text=node_text, max_num=self.max_num_themes),
-        )
-        return self.property_name, result.output
+        chunks = self.split_text_by_token_limit(node_text, self.max_token_limit)
+        themes = []
+        for chunk in chunks:
+            result = await self.prompt.generate(
+                self.llm,
+                data=TextWithExtractionLimit(text=chunk, max_num=self.max_num_themes),
+            )
+            themes.extend(result.output)
+
+        return self.property_name, themes
```
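Two patterns recur across this file: single-value extractors (summary, title, topic description) now run the prompt on chunks[0] only, while list-valued extractors (keyphrases, headlines, entities, themes) query every chunk and concatenate the results. Because max_token_limit and tokenizer are plain dataclass fields on LLMBasedExtractor, they should be overridable per instance; a hedged usage sketch (the constructor arguments are inferred from the field definitions above, not an API documented by this commit):

```python
import tiktoken
from ragas.testset.transforms.extractors.llm_based import KeyphrasesExtractor

# Assumption: the new LLMBasedExtractor fields are inherited by its dataclass
# subclasses and can be overridden at construction time.
extractor = KeyphrasesExtractor(
    max_token_limit=8_000,  # split node text into chunks of at most 8k tokens
    tokenizer=tiktoken.get_encoding("cl100k_base"),  # e.g. to match a gpt-4-era model
)
```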
