Skip to content

Commit 76d6f50

Browse files
committed
Push tokenizer throughout
1 parent 2cb6dc3 commit 76d6f50

File tree

16 files changed

+105
-109
lines changed

16 files changed

+105
-109
lines changed

graphrag/index/operations/summarize_communities/build_mixed_context.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@
88
from graphrag.index.operations.summarize_communities.graph_context.sort_context import (
99
sort_context,
1010
)
11-
from graphrag.query.llm.text_utils import num_tokens
11+
from graphrag.tokenizer.tokenizer import Tokenizer
1212

1313

14-
def build_mixed_context(context: list[dict], max_context_tokens: int) -> str:
14+
def build_mixed_context(
15+
context: list[dict], tokenizer: Tokenizer, max_context_tokens: int
16+
) -> str:
1517
"""
1618
Build parent context by concatenating all sub-communities' contexts.
1719
@@ -45,9 +47,10 @@ def build_mixed_context(context: list[dict], max_context_tokens: int) -> str:
4547
remaining_local_context.extend(sorted_context[rid][schemas.ALL_CONTEXT])
4648
new_context_string = sort_context(
4749
local_context=remaining_local_context + final_local_contexts,
50+
tokenizer=tokenizer,
4851
sub_community_reports=substitute_reports,
4952
)
50-
if num_tokens(new_context_string) <= max_context_tokens:
53+
if tokenizer.num_tokens(new_context_string) <= max_context_tokens:
5154
exceeded_limit = False
5255
context_string = new_context_string
5356
break
@@ -63,7 +66,7 @@ def build_mixed_context(context: list[dict], max_context_tokens: int) -> str:
6366
new_context_string = pd.DataFrame(substitute_reports).to_csv(
6467
index=False, sep=","
6568
)
66-
if num_tokens(new_context_string) > max_context_tokens:
69+
if tokenizer.num_tokens(new_context_string) > max_context_tokens:
6770
break
6871

6972
context_string = new_context_string

graphrag/index/operations/summarize_communities/graph_context/context_builder.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
where_column_equals,
3131
)
3232
from graphrag.logger.progress import progress_iterable
33-
from graphrag.query.llm.text_utils import num_tokens
33+
from graphrag.tokenizer.tokenizer import Tokenizer
3434

3535
logger = logging.getLogger(__name__)
3636

@@ -39,6 +39,7 @@ def build_local_context(
3939
nodes,
4040
edges,
4141
claims,
42+
tokenizer: Tokenizer,
4243
callbacks: WorkflowCallbacks,
4344
max_context_tokens: int = 16_000,
4445
):
@@ -49,7 +50,7 @@ def build_local_context(
4950

5051
for level in progress_iterable(levels, callbacks.progress, len(levels)):
5152
communities_at_level_df = _prepare_reports_at_level(
52-
nodes, edges, claims, level, max_context_tokens
53+
nodes, edges, claims, tokenizer, level, max_context_tokens
5354
)
5455

5556
communities_at_level_df.loc[:, schemas.COMMUNITY_LEVEL] = level
@@ -63,6 +64,7 @@ def _prepare_reports_at_level(
6364
node_df: pd.DataFrame,
6465
edge_df: pd.DataFrame,
6566
claim_df: pd.DataFrame | None,
67+
tokenizer: Tokenizer,
6668
level: int,
6769
max_context_tokens: int = 16_000,
6870
) -> pd.DataFrame:
@@ -181,6 +183,7 @@ def _prepare_reports_at_level(
181183
# Generate community-level context strings using vectorized batch processing
182184
return parallel_sort_context_batch(
183185
community_df,
186+
tokenizer=tokenizer,
184187
max_context_tokens=max_context_tokens,
185188
)
186189

@@ -189,6 +192,7 @@ def build_level_context(
189192
report_df: pd.DataFrame | None,
190193
community_hierarchy_df: pd.DataFrame,
191194
local_context_df: pd.DataFrame,
195+
tokenizer: Tokenizer,
192196
level: int,
193197
max_context_tokens: int,
194198
) -> pd.DataFrame:
@@ -219,11 +223,11 @@ def build_level_context(
219223

220224
if report_df is None or report_df.empty:
221225
invalid_context_df.loc[:, schemas.CONTEXT_STRING] = _sort_and_trim_context(
222-
invalid_context_df, max_context_tokens
226+
invalid_context_df, tokenizer, max_context_tokens
223227
)
224228
invalid_context_df[schemas.CONTEXT_SIZE] = invalid_context_df.loc[
225229
:, schemas.CONTEXT_STRING
226-
].map(num_tokens)
230+
].map(tokenizer.num_tokens)
227231
invalid_context_df[schemas.CONTEXT_EXCEED_FLAG] = False
228232
return union(valid_context_df, invalid_context_df)
229233

@@ -237,18 +241,21 @@ def build_level_context(
237241
invalid_context_df,
238242
sub_context_df,
239243
community_hierarchy_df,
244+
tokenizer,
240245
max_context_tokens,
241246
)
242247

243248
# handle any remaining invalid records that can't be substituted with sub-community reports
244249
# this should be rare, but if it happens, we will just trim the local context to fit the limit
245250
remaining_df = _antijoin_reports(invalid_context_df, community_df)
246251
remaining_df.loc[:, schemas.CONTEXT_STRING] = _sort_and_trim_context(
247-
remaining_df, max_context_tokens
252+
remaining_df, tokenizer, max_context_tokens
248253
)
249254

250255
result = union(valid_context_df, community_df, remaining_df)
251-
result[schemas.CONTEXT_SIZE] = result.loc[:, schemas.CONTEXT_STRING].map(num_tokens)
256+
result[schemas.CONTEXT_SIZE] = result.loc[:, schemas.CONTEXT_STRING].map(
257+
tokenizer.num_tokens
258+
)
252259

253260
result[schemas.CONTEXT_EXCEED_FLAG] = False
254261
return result
@@ -269,19 +276,29 @@ def _antijoin_reports(df: pd.DataFrame, reports: pd.DataFrame) -> pd.DataFrame:
269276
return antijoin(df, reports, schemas.COMMUNITY_ID)
270277

271278

272-
def _sort_and_trim_context(df: pd.DataFrame, max_context_tokens: int) -> pd.Series:
279+
def _sort_and_trim_context(
280+
df: pd.DataFrame, tokenizer: Tokenizer, max_context_tokens: int
281+
) -> pd.Series:
273282
"""Sort and trim context to fit the limit."""
274283
series = cast("pd.Series", df[schemas.ALL_CONTEXT])
275284
return transform_series(
276-
series, lambda x: sort_context(x, max_context_tokens=max_context_tokens)
285+
series,
286+
lambda x: sort_context(
287+
x, tokenizer=tokenizer, max_context_tokens=max_context_tokens
288+
),
277289
)
278290

279291

280-
def _build_mixed_context(df: pd.DataFrame, max_context_tokens: int) -> pd.Series:
292+
def _build_mixed_context(
293+
df: pd.DataFrame, tokenizer: Tokenizer, max_context_tokens: int
294+
) -> pd.Series:
281295
"""Sort and trim context to fit the limit."""
282296
series = cast("pd.Series", df[schemas.ALL_CONTEXT])
283297
return transform_series(
284-
series, lambda x: build_mixed_context(x, max_context_tokens=max_context_tokens)
298+
series,
299+
lambda x: build_mixed_context(
300+
x, tokenizer, max_context_tokens=max_context_tokens
301+
),
285302
)
286303

287304

@@ -303,6 +320,7 @@ def _get_community_df(
303320
invalid_context_df: pd.DataFrame,
304321
sub_context_df: pd.DataFrame,
305322
community_hierarchy_df: pd.DataFrame,
323+
tokenizer: Tokenizer,
306324
max_context_tokens: int,
307325
) -> pd.DataFrame:
308326
"""Get community context for each community."""
@@ -338,7 +356,7 @@ def _get_community_df(
338356
.reset_index()
339357
)
340358
community_df[schemas.CONTEXT_STRING] = _build_mixed_context(
341-
community_df, max_context_tokens
359+
community_df, tokenizer, max_context_tokens
342360
)
343361
community_df[schemas.COMMUNITY_LEVEL] = level
344362
return community_df

graphrag/index/operations/summarize_communities/graph_context/sort_context.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@
55
import pandas as pd
66

77
import graphrag.data_model.schemas as schemas
8-
from graphrag.query.llm.text_utils import num_tokens
8+
from graphrag.tokenizer.tokenizer import Tokenizer
99

1010

1111
def sort_context(
1212
local_context: list[dict],
13+
tokenizer: Tokenizer,
1314
sub_community_reports: list[dict] | None = None,
1415
max_context_tokens: int | None = None,
1516
node_name_column: str = schemas.TITLE,
@@ -112,7 +113,10 @@ def _get_context_string(
112113
new_context_string = _get_context_string(
113114
sorted_nodes, sorted_edges, sorted_claims, sub_community_reports
114115
)
115-
if max_context_tokens and num_tokens(new_context_string) > max_context_tokens:
116+
if (
117+
max_context_tokens
118+
and tokenizer.num_tokens(new_context_string) > max_context_tokens
119+
):
116120
break
117121
context_string = new_context_string
118122

@@ -122,7 +126,9 @@ def _get_context_string(
122126
)
123127

124128

125-
def parallel_sort_context_batch(community_df, max_context_tokens, parallel=False):
129+
def parallel_sort_context_batch(
130+
community_df, tokenizer: Tokenizer, max_context_tokens, parallel=False
131+
):
126132
"""Calculate context using parallelization if enabled."""
127133
if parallel:
128134
# Use ThreadPoolExecutor for parallel execution
@@ -131,7 +137,9 @@ def parallel_sort_context_batch(community_df, max_context_tokens, parallel=False
131137
with ThreadPoolExecutor(max_workers=None) as executor:
132138
context_strings = list(
133139
executor.map(
134-
lambda x: sort_context(x, max_context_tokens=max_context_tokens),
140+
lambda x: sort_context(
141+
x, tokenizer, max_context_tokens=max_context_tokens
142+
),
135143
community_df[schemas.ALL_CONTEXT],
136144
)
137145
)
@@ -141,13 +149,13 @@ def parallel_sort_context_batch(community_df, max_context_tokens, parallel=False
141149
# Assign context strings directly to the DataFrame
142150
community_df[schemas.CONTEXT_STRING] = community_df[schemas.ALL_CONTEXT].apply(
143151
lambda context_list: sort_context(
144-
context_list, max_context_tokens=max_context_tokens
152+
context_list, tokenizer, max_context_tokens=max_context_tokens
145153
)
146154
)
147155

148156
# Calculate other columns
149157
community_df[schemas.CONTEXT_SIZE] = community_df[schemas.CONTEXT_STRING].apply(
150-
num_tokens
158+
tokenizer.num_tokens
151159
)
152160
community_df[schemas.CONTEXT_EXCEED_FLAG] = (
153161
community_df[schemas.CONTEXT_SIZE] > max_context_tokens

graphrag/index/operations/summarize_communities/summarize_communities.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
)
2424
from graphrag.index.utils.derive_from_rows import derive_from_rows
2525
from graphrag.logger.progress import progress_ticker
26+
from graphrag.tokenizer.tokenizer import Tokenizer
2627

2728
logger = logging.getLogger(__name__)
2829

@@ -35,6 +36,7 @@ async def summarize_communities(
3536
callbacks: WorkflowCallbacks,
3637
cache: PipelineCache,
3738
strategy: dict,
39+
tokenizer: Tokenizer,
3840
max_input_length: int,
3941
async_mode: AsyncType = AsyncType.AsyncIO,
4042
num_threads: int = 4,
@@ -44,7 +46,6 @@ async def summarize_communities(
4446
tick = progress_ticker(callbacks.progress, len(local_contexts))
4547
strategy_exec = load_strategy(strategy["type"])
4648
strategy_config = {**strategy}
47-
4849
community_hierarchy = (
4950
communities.explode("children")
5051
.rename({"children": "sub_community"}, axis=1)
@@ -60,6 +61,7 @@ async def summarize_communities(
6061
community_hierarchy_df=community_hierarchy,
6162
local_context_df=local_contexts,
6263
level=level,
64+
tokenizer=tokenizer,
6365
max_context_tokens=max_input_length,
6466
)
6567
level_contexts.append(level_context)

graphrag/index/operations/summarize_communities/text_unit_context/context_builder.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from graphrag.index.operations.summarize_communities.text_unit_context.sort_context import (
1919
sort_context,
2020
)
21-
from graphrag.query.llm.text_utils import num_tokens
21+
from graphrag.tokenizer.tokenizer import Tokenizer
2222

2323
logger = logging.getLogger(__name__)
2424

@@ -27,6 +27,7 @@ def build_local_context(
2727
community_membership_df: pd.DataFrame,
2828
text_units_df: pd.DataFrame,
2929
node_df: pd.DataFrame,
30+
tokenizer: Tokenizer,
3031
max_context_tokens: int = 16000,
3132
) -> pd.DataFrame:
3233
"""
@@ -69,10 +70,10 @@ def build_local_context(
6970
.reset_index()
7071
)
7172
context_df[schemas.CONTEXT_STRING] = context_df[schemas.ALL_CONTEXT].apply(
72-
lambda x: sort_context(x)
73+
lambda x: sort_context(x, tokenizer)
7374
)
7475
context_df[schemas.CONTEXT_SIZE] = context_df[schemas.CONTEXT_STRING].apply(
75-
lambda x: num_tokens(x)
76+
lambda x: tokenizer.num_tokens(x)
7677
)
7778
context_df[schemas.CONTEXT_EXCEED_FLAG] = context_df[schemas.CONTEXT_SIZE].apply(
7879
lambda x: x > max_context_tokens
@@ -86,6 +87,7 @@ def build_level_context(
8687
community_hierarchy_df: pd.DataFrame,
8788
local_context_df: pd.DataFrame,
8889
level: int,
90+
tokenizer: Tokenizer,
8991
max_context_tokens: int = 16000,
9092
) -> pd.DataFrame:
9193
"""
@@ -116,10 +118,12 @@ def build_level_context(
116118

117119
invalid_context_df.loc[:, [schemas.CONTEXT_STRING]] = invalid_context_df[
118120
schemas.ALL_CONTEXT
119-
].apply(lambda x: sort_context(x, max_context_tokens=max_context_tokens))
121+
].apply(
122+
lambda x: sort_context(x, tokenizer, max_context_tokens=max_context_tokens)
123+
)
120124
invalid_context_df.loc[:, [schemas.CONTEXT_SIZE]] = invalid_context_df[
121125
schemas.CONTEXT_STRING
122-
].apply(lambda x: num_tokens(x))
126+
].apply(lambda x: tokenizer.num_tokens(x))
123127
invalid_context_df.loc[:, [schemas.CONTEXT_EXCEED_FLAG]] = False
124128

125129
return pd.concat([valid_context_df, invalid_context_df])
@@ -199,10 +203,10 @@ def build_level_context(
199203
.reset_index()
200204
)
201205
community_df[schemas.CONTEXT_STRING] = community_df[schemas.ALL_CONTEXT].apply(
202-
lambda x: build_mixed_context(x, max_context_tokens)
206+
lambda x: build_mixed_context(x, tokenizer, max_context_tokens)
203207
)
204208
community_df[schemas.CONTEXT_SIZE] = community_df[schemas.CONTEXT_STRING].apply(
205-
lambda x: num_tokens(x)
209+
lambda x: tokenizer.num_tokens(x)
206210
)
207211
community_df[schemas.CONTEXT_EXCEED_FLAG] = False
208212
community_df[schemas.COMMUNITY_LEVEL] = level
@@ -220,10 +224,10 @@ def build_level_context(
220224
)
221225
remaining_df[schemas.CONTEXT_STRING] = cast(
222226
"pd.DataFrame", remaining_df[schemas.ALL_CONTEXT]
223-
).apply(lambda x: sort_context(x, max_context_tokens=max_context_tokens))
227+
).apply(lambda x: sort_context(x, tokenizer, max_context_tokens=max_context_tokens))
224228
remaining_df[schemas.CONTEXT_SIZE] = cast(
225229
"pd.DataFrame", remaining_df[schemas.CONTEXT_STRING]
226-
).apply(lambda x: num_tokens(x))
230+
).apply(lambda x: tokenizer.num_tokens(x))
227231
remaining_df[schemas.CONTEXT_EXCEED_FLAG] = False
228232

229233
return cast(

graphrag/index/operations/summarize_communities/text_unit_context/sort_context.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import pandas as pd
99

1010
import graphrag.data_model.schemas as schemas
11-
from graphrag.query.llm.text_utils import num_tokens
11+
from graphrag.tokenizer.tokenizer import Tokenizer
1212

1313
logger = logging.getLogger(__name__)
1414

@@ -57,6 +57,7 @@ def get_context_string(
5757

5858
def sort_context(
5959
local_context: list[dict],
60+
tokenizer: Tokenizer,
6061
sub_community_reports: list[dict] | None = None,
6162
max_context_tokens: int | None = None,
6263
) -> str:
@@ -73,7 +74,7 @@ def sort_context(
7374
new_context_string = get_context_string(
7475
current_text_units, sub_community_reports
7576
)
76-
if num_tokens(new_context_string) > max_context_tokens:
77+
if tokenizer.num_tokens(new_context_string) > max_context_tokens:
7778
break
7879

7980
context_string = new_context_string

0 commit comments

Comments
 (0)