Skip to content

Commit bbe4b4e

Browse files
committed
feat(global_search): track max-token usage, pack as many community reports as possible into a single LLM call, and sort the key points by score
1 parent 033762f commit bbe4b4e

File tree

4 files changed

+96
-20
lines changed

4 files changed

+96
-20
lines changed

examples/simple-app/app/query.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def global_search(
7272
community_level=cast(CommunityLevel, level),
7373
weight_calculator=CommunityWeightCalculator(),
7474
artifacts=artifacts,
75+
token_counter=TiktokenCounter(),
7576
)
7677

7778
kp_generator = KeyPointsGenerator(
@@ -83,7 +84,9 @@ def global_search(
8384
kp_aggregator = KeyPointsAggregator(
8485
llm=make_llm_instance(llm_type, llm_model, cache_dir),
8586
prompt_builder=KeyPointsAggregatorPromptBuilder(),
86-
context_builder=KeyPointsContextBuilder(),
87+
context_builder=KeyPointsContextBuilder(
88+
token_counter=TiktokenCounter(),
89+
),
8790
)
8891

8992
global_search = GlobalSearch(

src/langchain_graphrag/query/global_search/key_points_aggregator/context_builder.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1+
import logging
2+
13
from langchain_core.documents import Document
24

35
from langchain_graphrag.query.global_search.key_points_generator.utils import (
46
KeyPointsResult,
57
)
8+
from langchain_graphrag.utils.token_counter import TokenCounter
69

710
_REPORT_TEMPLATE = """
811
--- {analyst} ---
@@ -13,17 +16,48 @@
1316
1417
"""
1518

19+
_LOGGER = logging.getLogger(__name__)
20+
1621

1722
class KeyPointsContextBuilder:
    """Turns analyst key points into context documents for aggregation,
    keeping only the highest-scoring points within a token budget.

    Points are sorted by importance score (descending) *before* the token
    budget is applied, so the budget always drops the lowest-scoring
    points. (Previously points were kept in insertion order and sorted
    only afterwards, which could discard high-scoring points from
    later analysts while keeping low-scoring ones from earlier analysts.)
    """

    def __init__(
        self,
        token_counter: TokenCounter,
        max_tokens: int = 8000,
    ):
        # token_counter measures prompt size; max_tokens caps the total
        # context handed to the aggregation LLM call.
        self._token_counter = token_counter
        self._max_tokens = max_tokens

    def __call__(self, key_points: dict[str, KeyPointsResult]) -> list[Document]:
        """Build score-ordered documents from *key_points*, stopping once
        the cumulative token count would exceed the budget."""
        # Flatten to (analyst, point) pairs and rank by score first.
        ranked = [(analyst, point) for analyst, result in key_points.items() for point in result.points]
        ranked.sort(key=lambda pair: pair[1].score, reverse=True)

        documents: list[Document] = []
        total_tokens = 0
        for analyst, point in ranked:
            report = _REPORT_TEMPLATE.format(
                analyst=analyst,
                score=point.score,
                content=point.description,
            )
            report_tokens = self._token_counter.count_tokens(report)
            if total_tokens + report_tokens > self._max_tokens:
                _LOGGER.warning("Reached max tokens for key points aggregation ...")
                break
            total_tokens += report_tokens
            documents.append(
                Document(
                    page_content=report,
                    metadata={"score": point.score, "analyst": analyst},
                )
            )

        # Already sorted by score (descending) by construction.
        return documents

src/langchain_graphrag/query/global_search/key_points_generator/context_builder.py

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import logging
2+
13
from langchain_core.documents import Document
24

35
from langchain_graphrag.indexing.artifacts import IndexerArtifacts
@@ -6,6 +8,7 @@
68
CommunityWeightCalculator,
79
)
810
from langchain_graphrag.types.graphs.community import CommunityId, CommunityLevel
11+
from langchain_graphrag.utils.token_counter import TokenCounter
912

1013
_REPORT_TEMPLATE = """
1114
--- Report {report_id} ---
@@ -19,17 +22,23 @@
1922
2023
"""
2124

25+
_LOGGER = logging.getLogger(__name__)
26+
2227

2328
class CommunityReportContextBuilder:
2429
def __init__(
    self,
    community_level: CommunityLevel,
    weight_calculator: CommunityWeightCalculator,
    artifacts: IndexerArtifacts,
    token_counter: TokenCounter,
    max_tokens: int = 8000,
):
    """Store the collaborators needed to build community-report context.

    community_level selects which communities are included,
    weight_calculator scores them, artifacts supplies the indexed data,
    and token_counter/max_tokens bound the size of each LLM call.
    """
    self._max_tokens = max_tokens
    self._token_counter = token_counter
    self._artifacts = artifacts
    self._weight_calculator = weight_calculator
    self._community_level = community_level
3342

3443
def _filter_communities(self) -> list[CommunityReport]:
3544
df_entities = self._artifacts.entities
def __call__(self) -> list[Document]:
    """Pack the filtered community reports into as few documents as
    possible, each staying within the max-token budget.

    Bug fixes relative to the previous version:
    - the report that triggered a "cut" was silently dropped (the old
      ``else`` meant it joined neither the flushed batch nor the next);
      it now starts the new batch instead;
    - the final accumulated batch was never flushed, so when every
      report fit within the budget this method returned an empty list.
    """
    reports = self._filter_communities()

    documents: list[Document] = []
    report_str_accumulated: list[str] = []
    token_count = 0

    for report in reports:
        # Combine multiple reports into a single document as long as
        # we do not exceed the token limit.
        report_str = _REPORT_TEMPLATE.format(
            report_id=report.id,
            title=report.title,
            weight=report.weight,
            rank=report.rank,
            content=report.content,
        )

        report_str_token_count = self._token_counter.count_tokens(report_str)

        # Cut a new document when adding this report would overflow the
        # budget. The non-empty guard means a single oversized report
        # still goes out alone rather than emitting an empty document.
        if report_str_accumulated and token_count + report_str_token_count > self._max_tokens:
            _LOGGER.debug("Reached max tokens for a community report call ...")
            documents.append(
                Document(
                    page_content="\n\n".join(report_str_accumulated),
                    metadata={"token_count": token_count},
                )
            )
            # Reset the token count and the accumulated string.
            token_count = 0
            report_str_accumulated = []

        # The current report always joins the (possibly fresh) batch.
        token_count += report_str_token_count
        report_str_accumulated.append(report_str)

    # Flush the final batch so trailing reports are not lost.
    if report_str_accumulated:
        documents.append(
            Document(
                page_content="\n\n".join(report_str_accumulated),
                metadata={"token_count": token_count},
            )
        )

    return documents

src/langchain_graphrag/query/global_search/search.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
1+
import logging
12
from typing import Iterator
23

3-
from langchain_core.runnables import Runnable, RunnablePassthrough
4-
54
from .key_points_aggregator import KeyPointsAggregator
65
from .key_points_generator import KeyPointsGenerator
6+
from .key_points_generator.utils import (
7+
KeyPointsResult,
8+
)
9+
10+
_LOGGER = logging.getLogger(__name__)
711

812

913
class GlobalSearch:
@@ -15,21 +19,27 @@ def __init__(
1519
self._kp_generator = kp_generator
1620
self._kp_aggregator = kp_aggregator
1721

18-
def invoke(self, query: str) -> str:
22+
def _get_key_points(self, query: str) -> dict[str, KeyPointsResult]:
    """Run the key-points generation chain for *query* and return the
    results keyed by analyst name."""
    generation_chain = self._kp_generator()
    response = generation_chain.invoke(query)

    # BUGFIX: isEnabledFor() is true for INFO *and* any more-verbose
    # level (e.g. DEBUG); the old exact == INFO comparison silently
    # skipped this logging when the logger was set to DEBUG.
    if _LOGGER.isEnabledFor(logging.INFO):
        for analyst, result in response.items():
            # Lazy %-args: the string is only built if the record is emitted.
            _LOGGER.info("%s - %d", analyst, len(result.points))

    return response
31+
32+
def invoke(self, query: str) -> str:
    """Answer *query* by aggregating the generated key points."""
    key_points = self._get_key_points(query)
    aggregation_chain = self._kp_aggregator()
    return aggregation_chain.invoke(
        input={"report_data": key_points, "global_query": query}
    )
2739

2840
def stream(self, query: str) -> Iterator:
    """Stream the aggregated answer for *query* chunk by chunk."""
    key_points = self._get_key_points(query)
    aggregation_chain = self._kp_aggregator()
    return aggregation_chain.stream(
        input={"report_data": key_points, "global_query": query}
    )

0 commit comments

Comments
 (0)