33
44"""All the steps to transform base text_units."""
55
6- from typing import cast
6+ import json
7+ from typing import Any , cast
78
89import pandas as pd
910
1011from graphrag .callbacks .workflow_callbacks import WorkflowCallbacks
1112from graphrag .config .models .chunking_config import ChunkStrategyType
1213from graphrag .index .operations .chunk_text .chunk_text import chunk_text
14+ from graphrag .index .operations .chunk_text .strategies import get_encoding_fn
1315from graphrag .index .utils .hashing import gen_sha512_hash
1416from graphrag .logger .progress import Progress
1517
@@ -22,6 +24,8 @@ def create_base_text_units(
2224 overlap : int ,
2325 encoding_model : str ,
2426 strategy : ChunkStrategyType ,
27+ prepend_metadata : bool = False ,
28+ chunk_size_includes_metadata : bool = False ,
2529) -> pd .DataFrame :
2630 """All the steps to transform base text_units."""
2731 sort = documents .sort_values (by = ["id" ], ascending = [True ])
@@ -32,25 +36,66 @@ def create_base_text_units(
3236
3337 callbacks .progress (Progress (percent = 0 ))
3438
39+ agg_dict = {"text_with_ids" : list }
40+ if "metadata" in documents :
41+ agg_dict ["metadata" ] = "first" # type: ignore
42+
3543 aggregated = (
3644 (
3745 sort .groupby (group_by_columns , sort = False )
3846 if len (group_by_columns ) > 0
3947 else sort .groupby (lambda _x : True )
4048 )
41- .agg (texts = ( "text_with_ids" , list ) )
49+ .agg (agg_dict )
4250 .reset_index ()
4351 )
52+ aggregated .rename (columns = {"text_with_ids" : "texts" }, inplace = True )
4453
45- aggregated ["chunks" ] = chunk_text (
46- aggregated ,
47- column = "texts" ,
48- size = size ,
49- overlap = overlap ,
50- encoding_model = encoding_model ,
51- strategy = strategy ,
52- callbacks = callbacks ,
53- )
def chunker(row: dict[str, Any]) -> Any:
    """Chunk a single aggregated row's texts, optionally prefixing each chunk with document metadata.

    When `prepend_metadata` is set, the row's metadata (a dict, or a JSON
    string decoded to one) is rendered as "key: value" lines and prepended
    to every produced chunk. When `chunk_size_includes_metadata` is also
    set, the metadata's token count is subtracted from the chunk budget so
    the final chunks stay within `size` tokens; raises ValueError if the
    metadata alone meets or exceeds that budget.
    Mutates and returns `row` with the chunks stored under "chunks".
    """
    delimiter = ".\n"
    header = ""
    header_tokens = 0

    if prepend_metadata and "metadata" in row:
        meta = row["metadata"]
        if isinstance(meta, str):
            meta = json.loads(meta)
        if isinstance(meta, dict):
            rendered = [f"{k}: {v}" for k, v in meta.items()]
            header = delimiter.join(rendered) + delimiter

        if chunk_size_includes_metadata:
            encode, _ = get_encoding_fn(encoding_model)
            header_tokens = len(encode(header))
            if header_tokens >= size:
                message = "Metadata tokens exceeds the maximum tokens per chunk. Please increase the tokens per chunk."
                raise ValueError(message)

    # chunk_text operates on a DataFrame; wrap the single row and take the
    # first (only) result back out.
    frame = pd.DataFrame([row]).reset_index(drop=True)
    chunks = chunk_text(
        frame,
        column="texts",
        size=size - header_tokens,
        overlap=overlap,
        encoding_model=encoding_model,
        strategy=strategy,
        callbacks=callbacks,
    )[0]

    if prepend_metadata:
        for i, piece in enumerate(chunks):
            if isinstance(piece, str):
                chunks[i] = header + piece
            elif piece:
                # tuple form: (doc_ids, text, n_tokens) — prefix the text field
                chunks[i] = (piece[0], header + piece[1], piece[2])
            else:
                chunks[i] = None

    row["chunks"] = chunks
    return row
97+
98+ aggregated = aggregated .apply (lambda row : chunker (row ), axis = 1 )
5499
55100 aggregated = cast ("pd.DataFrame" , aggregated [[* group_by_columns , "chunks" ]])
56101 aggregated = aggregated .explode ("chunks" )